Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
7ad3113
1
Parent(s):
3c5de4a
modify
Browse files
app.py
CHANGED
@@ -50,7 +50,7 @@ class GradioArgs:
|
|
50 |
if self.ratio is None:
|
51 |
self.ratio = [0.68, 0.88]
|
52 |
|
53 |
-
|
54 |
def prune_model(pipe, hookers):
|
55 |
# remove parameters in attention blocks
|
56 |
cross_attn_hooker = hookers[0]
|
@@ -91,17 +91,18 @@ def prune_model(pipe, hookers):
|
|
91 |
ffn_hook.clear_hooks()
|
92 |
return pipe
|
93 |
|
|
|
94 |
def binary_mask_eval(args):
|
95 |
# load sdxl model
|
96 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
97 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
98 |
-
)
|
99 |
|
100 |
torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
|
101 |
mask_pipe, hookers = create_pipeline(
|
102 |
pipe,
|
103 |
args.model,
|
104 |
-
|
105 |
torch_dtype,
|
106 |
args.ckpt,
|
107 |
binary=args.binary,
|
@@ -131,7 +132,7 @@ def binary_mask_eval(args):
|
|
131 |
# reload the original model
|
132 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
133 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
134 |
-
)
|
135 |
|
136 |
# get model param summary
|
137 |
print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
|
|
|
50 |
if self.ratio is None:
|
51 |
self.ratio = [0.68, 0.88]
|
52 |
|
53 |
+
@spaces.GPU
|
54 |
def prune_model(pipe, hookers):
|
55 |
# remove parameters in attention blocks
|
56 |
cross_attn_hooker = hookers[0]
|
|
|
91 |
ffn_hook.clear_hooks()
|
92 |
return pipe
|
93 |
|
94 |
+
@spaces.GPU
|
95 |
def binary_mask_eval(args):
|
96 |
# load sdxl model
|
97 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
98 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
99 |
+
).to(device)
|
100 |
|
101 |
torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
|
102 |
mask_pipe, hookers = create_pipeline(
|
103 |
pipe,
|
104 |
args.model,
|
105 |
+
device,
|
106 |
torch_dtype,
|
107 |
args.ckpt,
|
108 |
binary=args.binary,
|
|
|
132 |
# reload the original model
|
133 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
134 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
135 |
+
).to(device)
|
136 |
|
137 |
# get model param summary
|
138 |
print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
|