jadechoghari committed on
Commit
b99cc11
1 Parent(s): 7a073d1

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +4 -4
generate.py CHANGED
@@ -12,8 +12,8 @@ from .utils import prepare_control, load_latent, load_video, prepare_depth, save
12
  from .pnp_utils import register_time, register_attention_control, register_conv_control
13
 
14
  # will cause an issue
15
- from . import vidtome
16
-
17
  # suppress partial model loading warning
18
  logging.set_verbosity_error()
19
 
@@ -95,7 +95,7 @@ class Generator(nn.Module):
95
  self.pipe.load_lora_weights(**gene_config.lora)
96
 
97
  def activate_vidtome(self):
98
- vidtome.apply_patch(self.pipe, self.local_merge_ratio, self.merge_global, self.global_merge_ratio,
99
  seed = self.seed, batch_size = self.batch_size, align_batch = self.use_pnp or self.align_batch, global_rand = self.global_rand)
100
 
101
  @torch.no_grad()
@@ -234,7 +234,7 @@ class Generator(nn.Module):
234
  def post_iter(self, x, t):
235
  if self.merge_global:
236
  # Reset global tokens
237
- vidtome.update_patch(self.pipe, global_tokens = None)
238
 
239
  @torch.no_grad()
240
  def pred_noise(self, x, cond, t, batch_idx=None):
 
12
  from .pnp_utils import register_time, register_attention_control, register_conv_control
13
 
14
  # will cause an issue
15
+ # from . import vidtome
16
 + from .vidtome import apply_patch, update_patch
17
  # suppress partial model loading warning
18
  logging.set_verbosity_error()
19
 
 
95
  self.pipe.load_lora_weights(**gene_config.lora)
96
 
97
  def activate_vidtome(self):
98
+ apply_patch(self.pipe, self.local_merge_ratio, self.merge_global, self.global_merge_ratio,
99
  seed = self.seed, batch_size = self.batch_size, align_batch = self.use_pnp or self.align_batch, global_rand = self.global_rand)
100
 
101
  @torch.no_grad()
 
234
  def post_iter(self, x, t):
235
  if self.merge_global:
236
  # Reset global tokens
237
+ update_patch(self.pipe, global_tokens = None)
238
 
239
  @torch.no_grad()
240
  def pred_noise(self, x, cond, t, batch_idx=None):