tolgacangoz commited on
Commit
1456055
·
verified ·
1 Parent(s): 484ba85

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +10 -3
matryoshka.py CHANGED
@@ -20,6 +20,7 @@
20
 
21
 
22
  import inspect
 
23
  import math
24
  from dataclasses import dataclass
25
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -3864,6 +3865,9 @@ class MatryoshkaPipeline(
3864
  else:
3865
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3866
 
 
 
 
3867
  def encode_prompt(
3868
  self,
3869
  prompt,
@@ -4645,9 +4649,12 @@ class MatryoshkaPipeline(
4645
  image = latents
4646
 
4647
  if self.scheduler.scales is not None:
4648
- for i in range(len(image)):
4649
- image[i] = image[i] * self.scheduler.scales[i]
4650
- image[i] = self.image_processor.postprocess(image[i], output_type=output_type)[0]
 
 
 
4651
  else:
4652
  image = self.image_processor.postprocess(image, output_type=output_type)
4653
 
 
20
 
21
 
22
  import inspect
23
+ import gc
24
  import math
25
  from dataclasses import dataclass
26
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
3865
  else:
3866
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3867
 
3868
+ gc.collect()
3869
+ torch.cuda.empty_cache()
3870
+
3871
  def encode_prompt(
3872
  self,
3873
  prompt,
 
4649
  image = latents
4650
 
4651
  if self.scheduler.scales is not None:
4652
+ scales = [
4653
+ image[i].size(-1) / image[-1].size(-1)
4654
+ for i in range(len(image))
4655
+ ]
4656
+ for i, (img, scale) in enumerate(zip(image, scales)):
4657
+ image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0]
4658
  else:
4659
  image = self.image_processor.postprocess(image, output_type=output_type)
4660