Spaces:
Sleeping
Sleeping
burakcanbiner
commited on
Update pnp.py
Browse files
pnp.py
CHANGED
@@ -67,11 +67,17 @@ class PNP(nn.Module):
|
|
67 |
).to("cuda")
|
68 |
|
69 |
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
75 |
#unet.to(self.device);
|
76 |
|
77 |
#pipe.unet = unet.to(self.device);
|
@@ -102,7 +108,7 @@ class PNP(nn.Module):
|
|
102 |
# self.set_audio_projector(adapter_ckpt_path, audio_projector_ckpt_path)
|
103 |
self.text_encoder = self.text_encoder.cuda()
|
104 |
|
105 |
-
|
106 |
|
107 |
self.audio_projector_ckpt_path = audio_projector_ckpt_path
|
108 |
self.adapter_ckpt_path = adapter_ckpt_path
|
|
|
67 |
).to("cuda")
|
68 |
|
69 |
|
70 |
+
|
71 |
+
audio_projector_path = "ckpts/audio_projector_landscape.pth"
|
72 |
+
adapter_ckpt_path = "ckpts/landscape.pt"
|
73 |
+
#self.pnp.set_audio_projector(gate_dict_path, audio_projector_path)
|
74 |
+
|
75 |
+
gate_dict = torch.load(adapter_ckpt_path)
|
76 |
|
77 |
+
for name, param in self.unet.named_parameters():
|
78 |
+
if "adapter" in name:
|
79 |
+
param.data = gate_dict[name]
|
80 |
+
|
81 |
#unet.to(self.device);
|
82 |
|
83 |
#pipe.unet = unet.to(self.device);
|
|
|
108 |
# self.set_audio_projector(adapter_ckpt_path, audio_projector_ckpt_path)
|
109 |
self.text_encoder = self.text_encoder.cuda()
|
110 |
|
111 |
+
self.audio_projector.load_state_dict(torch.load(audio_projector_path))
|
112 |
|
113 |
self.audio_projector_ckpt_path = audio_projector_ckpt_path
|
114 |
self.adapter_ckpt_path = adapter_ckpt_path
|