CSH-1220 committed
Commit · 24363dc · 1 Parent(s): 1834911
Update how we load pre-trained weights
Files changed:
- app.py +0 -15
- audio_encoder/AudioMAE.py +6 -1
- pipeline/morph_pipeline_successed_ver1.py +7 -3
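The pattern behind this commit: hf_hub_download downloads a file into the local Hugging Face cache once and returns the cached file's path, so each module can fetch its own weights on demand instead of app.py eagerly copying them into the working directory with local_dir. A minimal sketch of the new pattern, using the repo and filename from the diffs below:

import torch
from huggingface_hub import hf_hub_download

# The first call downloads into the HF cache (~/.cache/huggingface/hub
# by default) and returns the cached path; repeated calls just return
# the same path, so every load site can call this independently.
checkpoint_path = hf_hub_download(
    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
    filename="pretrained.pth",
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")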
app.py
CHANGED
@@ -3,21 +3,6 @@ import torch
 import torchaudio
 import numpy as np
 import gradio as gr
-from huggingface_hub import hf_hub_download
-model_path = hf_hub_download(
-    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
-    filename="pretrained.pth",
-    local_dir="./",
-    local_dir_use_symlinks=False
-)
-
-model_path = hf_hub_download(
-    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
-    filename="pytorch_model.bin",
-    local_dir="./",
-    local_dir_use_symlinks=False
-)
-
 from pipeline.morph_pipeline_successed_ver1 import AudioLDM2MorphPipeline
 # Initialize AudioLDM2 Pipeline
 pipeline = AudioLDM2MorphPipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
audio_encoder/AudioMAE.py
CHANGED
@@ -12,6 +12,7 @@ import librosa.display
 import matplotlib.pyplot as plt
 import numpy as np
 import torchaudio
+from huggingface_hub import hf_hub_download
 
 # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
 class Vanilla_AudioMAE(nn.Module):
@@ -25,7 +26,11 @@ class Vanilla_AudioMAE(nn.Module):
             in_chans=1, audio_exp=True, img_size=(1024, 128)
         )
 
-        checkpoint_path = 'pretrained.pth'
+        # checkpoint_path = 'pretrained.pth'
+        checkpoint_path = hf_hub_download(
+            repo_id="DennisHung/Pre-trained_AudioMAE_weights",
+            filename="pretrained.pth"
+        )
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         msg = model.load_state_dict(checkpoint['model'], strict=False)
 
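Because the checkpoint is applied with strict=False, load_state_dict silently skips any keys that do not line up, so it is worth confirming once that the cached file behaves like the old local 'pretrained.pth' did. A sketch of such a check; load_mae_checkpoint is a hypothetical helper, not part of this repo:

import torch

def load_mae_checkpoint(model, checkpoint_path):
    # torch.load on the hf_hub_download path behaves exactly as on the
    # old relative path; only the file's location changed.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    # msg is a NamedTuple of (missing_keys, unexpected_keys); print it
    # once to verify the same keys are skipped as before the change.
    print("missing:", msg.missing_keys)
    print("unexpected:", msg.unexpected_keys)
    return model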
pipeline/morph_pipeline_successed_ver1.py
CHANGED
@@ -49,8 +49,7 @@ if is_librosa_available():
 import librosa
 import warnings
 import matplotlib.pyplot as plt
-
-
+from huggingface_hub import hf_hub_download
 from .pipeline_audioldm2 import AudioLDM2Pipeline
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -91,7 +90,12 @@ for name in unet.attn_processors.keys():
     else:
         attn_procs[name] = AttnProcessor2_0()
 
-
+adapter_weight = hf_hub_download(
+    repo_id="DennisHung/Pre-trained_AudioMAE_weights",
+    filename="pytorch_model.bin",
+)
+
+state_dict = torch.load(adapter_weight, map_location=DEVICE)
 for name, processor in attn_procs.items():
     if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
         weight_name_v = name + ".to_v_ip.weight"
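The hunk ends mid-loop; what is visible is that the adapter weights live in a flat state dict keyed by processor name plus a suffix like ".to_v_ip.weight". A sketch of how such a loop could finish; weight_name_k, the Parameter assignment, and the helper name are assumptions inferred from the to_v_ip/to_k_ip naming, not code shown in this diff:

import torch

def load_ip_adapter_weights(attn_procs, state_dict, device):
    # Copy each processor's projection weights out of the flat state
    # dict loaded from pytorch_model.bin.
    for name, processor in attn_procs.items():
        if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
            weight_name_v = name + ".to_v_ip.weight"
            weight_name_k = name + ".to_k_ip.weight"
            processor.to_v_ip.weight = torch.nn.Parameter(
                state_dict[weight_name_v].to(device))
            processor.to_k_ip.weight = torch.nn.Parameter(
                state_dict[weight_name_k].to(device))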