ideprado commited on
Commit
e52e8c8
·
1 Parent(s): 5260f7d

Fix loading problem

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -9,6 +9,7 @@ from pikigen import PikigenPipeline
9
  # Trick required because it is not a native diffusers model
10
  from diffusers.pipelines.pipeline_loading_utils import LOADABLE_CLASSES, ALL_IMPORTABLE_CLASSES
11
  LOADABLE_CLASSES.setdefault("pikigen", {}).setdefault("DiT", []).extend(["save_pretrained", "from_pretrained"])
 
12
  ALL_IMPORTABLE_CLASSES.setdefault("DiT", []).extend(["save_pretrained", "from_pretrained"])
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -20,7 +21,7 @@ else:
20
  torch_dtype = torch.float32
21
 
22
  pipe = PikigenPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
23
- pipe.enable_model_cpu_offload() # For less memory consumption
24
  pipe.vae.enable_slicing()
25
  pipe.vae.enable_tiling()
26
 
 
9
  # Trick required because it is not a native diffusers model
10
  from diffusers.pipelines.pipeline_loading_utils import LOADABLE_CLASSES, ALL_IMPORTABLE_CLASSES
11
  LOADABLE_CLASSES.setdefault("pikigen", {}).setdefault("DiT", []).extend(["save_pretrained", "from_pretrained"])
12
+ LOADABLE_CLASSES["pikigen.model"] = LOADABLE_CLASSES["pikigen"]
13
  ALL_IMPORTABLE_CLASSES.setdefault("DiT", []).extend(["save_pretrained", "from_pretrained"])
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
21
  torch_dtype = torch.float32
22
 
23
  pipe = PikigenPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
24
+ # pipe.enable_model_cpu_offload() # For less memory consumption
25
  pipe.vae.enable_slicing()
26
  pipe.vae.enable_tiling()
27