CSH-1220 committed
Commit 24363dc · 1 Parent(s): 1834911

Update how we load pre-trained weights

app.py CHANGED
@@ -3,21 +3,6 @@ import torch
 import torchaudio
 import numpy as np
 import gradio as gr
-from huggingface_hub import hf_hub_download
-model_path = hf_hub_download(
-    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
-    filename="pretrained.pth",
-    local_dir="./",
-    local_dir_use_symlinks=False
-)
-
-model_path = hf_hub_download(
-    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
-    filename="pytorch_model.bin",
-    local_dir="./",
-    local_dir_use_symlinks=False
-)
-
 from pipeline.morph_pipeline_successed_ver1 import AudioLDM2MorphPipeline
 # Initialize AudioLDM2 Pipeline
 pipeline = AudioLDM2MorphPipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
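With this commit app.py no longer materializes both checkpoints in the repository root at import time; each module fetches the file it needs when it needs it. Below is a minimal sketch of the two download styles, assuming the DennisHung/Pre-trained_AudioMAE_weights repo remains publicly readable; newer huggingface_hub releases also deprecate local_dir_use_symlinks, so dropping it avoids a deprecation warning.

from huggingface_hub import hf_hub_download

# Old style (removed above): materialize the file next to the app
# instead of in the shared Hub cache.
eager_path = hf_hub_download(
    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
    filename="pretrained.pth",
    local_dir="./",
)

# New style (moved into the modules that actually load the weights):
# download into the default Hub cache, or reuse the cached copy.
cached_path = hf_hub_download(
    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
    filename="pretrained.pth",
)
print(cached_path)  # resolves somewhere under ~/.cache/huggingface/hub/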
audio_encoder/AudioMAE.py CHANGED
@@ -12,6 +12,7 @@ import librosa.display
 import matplotlib.pyplot as plt
 import numpy as np
 import torchaudio
+from huggingface_hub import hf_hub_download
 
 # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
 class Vanilla_AudioMAE(nn.Module):
@@ -25,7 +26,11 @@ class Vanilla_AudioMAE(nn.Module):
             in_chans=1, audio_exp=True, img_size=(1024, 128)
         )
 
-        checkpoint_path = 'pretrained.pth'
+        # checkpoint_path = 'pretrained.pth'
+        checkpoint_path = hf_hub_download(
+            repo_id="DennisHung/Pre-trained_AudioMAE_weights",
+            filename="pretrained.pth"
+        )
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         msg = model.load_state_dict(checkpoint['model'], strict=False)
 
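The encoder now resolves its checkpoint through the Hub cache at construction time, so nothing has to be staged in the working directory beforehand. A hedged sketch of that loading path, assuming (as the diff implies) the file is a plain torch pickle whose weights sit under a 'model' key:

import torch
from huggingface_hub import hf_hub_download

# Resolve the AudioMAE checkpoint via the Hub cache instead of a
# hard-coded 'pretrained.pth' in the current working directory.
checkpoint_path = hf_hub_download(
    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
    filename="pretrained.pth",
)
# On torch >= 2.6 you may need weights_only=False for this older pickle format.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# The diff loads checkpoint['model'] with strict=False so that extra or
# missing keys do not abort model construction; here we just inspect it.
print(sorted(checkpoint.keys()))
print(len(checkpoint["model"]), "tensors in the 'model' state dict")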
pipeline/morph_pipeline_successed_ver1.py CHANGED
@@ -49,8 +49,7 @@ if is_librosa_available():
 import librosa
 import warnings
 import matplotlib.pyplot as plt
-
-
+from huggingface_hub import hf_hub_download
 from .pipeline_audioldm2 import AudioLDM2Pipeline
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -91,7 +90,12 @@ for name in unet.attn_processors.keys():
     else:
         attn_procs[name] = AttnProcessor2_0()
 
-state_dict = torch.load('pytorch_model.bin', map_location=DEVICE)
+adapter_weight = hf_hub_download(
+    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
+    filename="pytorch_model.bin",
+)
+
+state_dict = torch.load(adapter_weight, map_location=DEVICE)
 for name, processor in attn_procs.items():
     if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
         weight_name_v = name + ".to_v_ip.weight"
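The adapter projections for the attention processors come from the same Hub repo. The sketch below is illustrative only: it downloads pytorch_model.bin and groups its keys by processor name, following the '<name>.to_v_ip.weight' / '<name>.to_k_ip.weight' pattern the pipeline code looks up.

import torch
from huggingface_hub import hf_hub_download

adapter_weight = hf_hub_download(
    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
    filename="pytorch_model.bin",
)
state_dict = torch.load(adapter_weight, map_location="cpu")

# Group the flat keys by attention-processor name; only the key pattern
# comes from the pipeline code above, the grouping is for inspection.
per_processor = {}
for key in state_dict:
    if key.endswith((".to_v_ip.weight", ".to_k_ip.weight")):
        proc_name, _, suffix = key.rpartition(".to_")
        per_processor.setdefault(proc_name, []).append("to_" + suffix)

print(len(per_processor), "processors carry IP-adapter-style projections")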