Spaces:
Build error
Build error
Commit
·
5e20c42
1
Parent(s):
7ad3113
modify
Browse files
app.py
CHANGED
|
@@ -14,6 +14,7 @@ from diffusers import StableDiffusionXLPipeline
|
|
| 14 |
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
|
|
|
|
| 17 |
def get_model_param_summary(model, verbose=False):
|
| 18 |
params_dict = dict()
|
| 19 |
overall_params = 0
|
|
@@ -50,7 +51,7 @@ class GradioArgs:
|
|
| 50 |
if self.ratio is None:
|
| 51 |
self.ratio = [0.68, 0.88]
|
| 52 |
|
| 53 |
-
|
| 54 |
def prune_model(pipe, hookers):
|
| 55 |
# remove parameters in attention blocks
|
| 56 |
cross_attn_hooker = hookers[0]
|
|
@@ -91,18 +92,18 @@ def prune_model(pipe, hookers):
|
|
| 91 |
ffn_hook.clear_hooks()
|
| 92 |
return pipe
|
| 93 |
|
| 94 |
-
|
| 95 |
def binary_mask_eval(args):
|
| 96 |
# load sdxl model
|
| 97 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 98 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
| 99 |
-
).to(
|
| 100 |
|
| 101 |
torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
|
| 102 |
mask_pipe, hookers = create_pipeline(
|
| 103 |
pipe,
|
| 104 |
args.model,
|
| 105 |
-
|
| 106 |
torch_dtype,
|
| 107 |
args.ckpt,
|
| 108 |
binary=args.binary,
|
|
@@ -132,7 +133,7 @@ def binary_mask_eval(args):
|
|
| 132 |
# reload the original model
|
| 133 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 134 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
| 135 |
-
).to(
|
| 136 |
|
| 137 |
# get model param summary
|
| 138 |
print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
|
|
@@ -140,12 +141,15 @@ def binary_mask_eval(args):
|
|
| 140 |
print("prune complete")
|
| 141 |
return pipe, pruned_pipe
|
| 142 |
|
|
|
|
| 143 |
@spaces.GPU
|
| 144 |
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
|
|
|
|
|
|
|
| 145 |
# Run the model and return images directly
|
| 146 |
-
g_cpu = torch.Generator(
|
| 147 |
original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
|
| 148 |
-
g_cpu = torch.Generator(
|
| 149 |
ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
|
| 150 |
return original_image, ecodiff_image
|
| 151 |
|
|
@@ -177,8 +181,8 @@ def create_demo():
|
|
| 177 |
with gr.Row():
|
| 178 |
model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
|
| 179 |
pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
|
| 180 |
-
prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
|
| 181 |
status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
|
|
|
|
| 182 |
with gr.Row():
|
| 183 |
prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
|
| 184 |
seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
|
|
|
|
| 14 |
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
|
| 17 |
+
|
| 18 |
def get_model_param_summary(model, verbose=False):
|
| 19 |
params_dict = dict()
|
| 20 |
overall_params = 0
|
|
|
|
| 51 |
if self.ratio is None:
|
| 52 |
self.ratio = [0.68, 0.88]
|
| 53 |
|
| 54 |
+
|
| 55 |
def prune_model(pipe, hookers):
|
| 56 |
# remove parameters in attention blocks
|
| 57 |
cross_attn_hooker = hookers[0]
|
|
|
|
| 92 |
ffn_hook.clear_hooks()
|
| 93 |
return pipe
|
| 94 |
|
| 95 |
+
|
| 96 |
def binary_mask_eval(args):
|
| 97 |
# load sdxl model
|
| 98 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 99 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
| 100 |
+
).to("cpu")
|
| 101 |
|
| 102 |
torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
|
| 103 |
mask_pipe, hookers = create_pipeline(
|
| 104 |
pipe,
|
| 105 |
args.model,
|
| 106 |
+
"cpu",
|
| 107 |
torch_dtype,
|
| 108 |
args.ckpt,
|
| 109 |
binary=args.binary,
|
|
|
|
| 133 |
# reload the original model
|
| 134 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 135 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
| 136 |
+
).to("cpu")
|
| 137 |
|
| 138 |
# get model param summary
|
| 139 |
print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
|
|
|
|
| 141 |
print("prune complete")
|
| 142 |
return pipe, pruned_pipe
|
| 143 |
|
| 144 |
+
|
| 145 |
@spaces.GPU
|
| 146 |
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
|
| 147 |
+
pipe.to("cuda")
|
| 148 |
+
pruned_pipe.to("cuda")
|
| 149 |
# Run the model and return images directly
|
| 150 |
+
g_cpu = torch.Generator("cuda").manual_seed(seed)
|
| 151 |
original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
|
| 152 |
+
g_cpu = torch.Generator("cuda").manual_seed(seed)
|
| 153 |
ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
|
| 154 |
return original_image, ecodiff_image
|
| 155 |
|
|
|
|
| 181 |
with gr.Row():
|
| 182 |
model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
|
| 183 |
pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
|
|
|
|
| 184 |
status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
|
| 185 |
+
prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
|
| 186 |
with gr.Row():
|
| 187 |
prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
|
| 188 |
seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
|