jadechoghari
commited on
Commit
•
b99cc11
1
Parent(s):
7a073d1
Update generate.py
Browse files- generate.py +4 -4
generate.py
CHANGED
@@ -12,8 +12,8 @@ from .utils import prepare_control, load_latent, load_video, prepare_depth, save
|
|
12 |
from .pnp_utils import register_time, register_attention_control, register_conv_control
|
13 |
|
14 |
# will cause an issue
|
15 |
-
from . import vidtome
|
16 |
-
|
17 |
# suppress partial model loading warning
|
18 |
logging.set_verbosity_error()
|
19 |
|
@@ -95,7 +95,7 @@ class Generator(nn.Module):
|
|
95 |
self.pipe.load_lora_weights(**gene_config.lora)
|
96 |
|
97 |
def activate_vidtome(self):
|
98 |
-
|
99 |
seed = self.seed, batch_size = self.batch_size, align_batch = self.use_pnp or self.align_batch, global_rand = self.global_rand)
|
100 |
|
101 |
@torch.no_grad()
|
@@ -234,7 +234,7 @@ class Generator(nn.Module):
|
|
234 |
def post_iter(self, x, t):
|
235 |
if self.merge_global:
|
236 |
# Reset global tokens
|
237 |
-
|
238 |
|
239 |
@torch.no_grad()
|
240 |
def pred_noise(self, x, cond, t, batch_idx=None):
|
|
|
12 |
from .pnp_utils import register_time, register_attention_control, register_conv_control
|
13 |
|
14 |
# will cause an issue
|
15 |
+
# from . import vidtome
|
16 |
+
from .vidtome import update_patch, update_patch
|
17 |
# suppress partial model loading warning
|
18 |
logging.set_verbosity_error()
|
19 |
|
|
|
95 |
self.pipe.load_lora_weights(**gene_config.lora)
|
96 |
|
97 |
def activate_vidtome(self):
|
98 |
+
apply_patch(self.pipe, self.local_merge_ratio, self.merge_global, self.global_merge_ratio,
|
99 |
seed = self.seed, batch_size = self.batch_size, align_batch = self.use_pnp or self.align_batch, global_rand = self.global_rand)
|
100 |
|
101 |
@torch.no_grad()
|
|
|
234 |
def post_iter(self, x, t):
|
235 |
if self.merge_global:
|
236 |
# Reset global tokens
|
237 |
+
update_patch(self.pipe, global_tokens = None)
|
238 |
|
239 |
@torch.no_grad()
|
240 |
def pred_noise(self, x, cond, t, batch_idx=None):
|