Spaces · Running on Zero
Commit 48cfefd: date banner
Parent: a3181aa
README.md CHANGED

@@ -9,8 +9,6 @@ app_file: app.py
 pinned: true
 python_version: 3.10
 short_description: Diffusion Model Compression
-hf_oauth: true
-hf_oauth_expiration_minutes: 43200
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED

@@ -38,19 +38,15 @@ def binary_mask_eval(args, model):
             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
         ).to("cpu")
         pruned_pipe.unet = torch.load(
-            hf_hub_download(
-                "zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"
-            ),
+            hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
             map_location="cpu",
         )
     elif model == "flux":
-        pruned_pipe = FluxPipeline.from_pretrained(
-            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
-        ).to("cpu")
+        pruned_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(
+            "cpu"
+        )
         pruned_pipe.transformer = torch.load(
-            hf_hub_download(
-                "zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"
-            ),
+            hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"),
             map_location="cpu",
         )
 
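The pattern this hunk reformats is worth seeing on its own: download a pickled, pruned module from the Hub, then graft it onto a stock pipeline. A minimal sketch, using the repo and file names from the diff; everything else is illustrative, and the `weights_only` note depends on your PyTorch version:

```python
# Sketch: load a pruned UNet from the Hub and attach it to a stock SDXL pipeline.
import torch
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cpu")

# hf_hub_download returns a local cache path; torch.load unpickles the whole
# pruned module. On PyTorch >= 2.6 you may need weights_only=False, since the
# checkpoint is a full pickled nn.Module rather than a plain state dict.
pipe.unet = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
    map_location="cpu",
)
```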
@@ -60,9 +56,7 @@ def binary_mask_eval(args, model):
             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
         ).to("cpu")
     elif model == "flux":
-        pipe = FluxPipeline.from_pretrained(
-            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
-        ).to("cpu")
+        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cpu")
 
     print("prune complete")
     return pipe, pruned_pipe
@@ -74,13 +68,9 @@ def generate_images(prompt, seed, steps, pipe, pruned_pipe):
     pruned_pipe.to("cuda")
     # Run the model and return images directly
     g_cpu = torch.Generator("cuda").manual_seed(seed)
-    original_image = pipe(
-        prompt=prompt, generator=g_cpu, num_inference_steps=steps
-    ).images[0]
+    original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
     g_cpu = torch.Generator("cuda").manual_seed(seed)
-    ecodiff_image = pruned_pipe(
-        prompt=prompt, generator=g_cpu, num_inference_steps=steps
-    ).images[0]
+    ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
     return original_image, ecodiff_image
 
 
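This hunk also preserves the detail that makes the side-by-side comparison fair: the generator is re-seeded before each call, so the original and pruned pipelines denoise from identical initial noise. A hedged sketch of that pattern as a standalone helper (`seeded_pair` is a hypothetical name, not part of the app; it assumes a CUDA device, as the app does):

```python
import torch

def seeded_pair(pipe_a, pipe_b, prompt: str, seed: int, steps: int):
    """Run two diffusers pipelines from identical initial noise."""
    # A fresh, identically seeded generator per call is essential: reusing one
    # generator would advance its state between the two runs and break parity.
    g = torch.Generator("cuda").manual_seed(seed)
    img_a = pipe_a(prompt=prompt, generator=g, num_inference_steps=steps).images[0]
    g = torch.Generator("cuda").manual_seed(seed)
    img_b = pipe_b(prompt=prompt, generator=g, num_inference_steps=steps).images[0]
    return img_a, img_b
```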
@@ -91,15 +81,29 @@ def on_prune_click(prompt, seed, steps, model):
 
 
 def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
-    original_image, ecodiff_image = generate_images(
-        prompt, seed, steps, pipe, pruned_pipe
-    )
+    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
     return original_image, ecodiff_image
 
 
+header = """
+# 🌱 Text-to-Image Generation with EcoDiff Pruned Models
+<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
+    <a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
+    <a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
+    <a href="https://github.com/YaNgZhAnG-V5/EcoDiff"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
+</div>
+
+<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
+    For **⚡faster⚡** DEMO on one model only, please visit
+    <a href="https://huggingface.co/spaces/zhangyang-0123/EcoDiff-SD-XL"><img alt="Static Badge" src="https://img.shields.io/badge/SDXL-fedcba.svg"></a>
+    <a href="https://huggingface.co/spaces/zhangyang-0123/EcoDiff-FLUX-Schnell"><img alt="Static Badge" src="https://img.shields.io/badge/FLUX-fgdfba"></a>
+</div>
+"""
+
+
 def create_demo():
     with gr.Blocks() as demo:
-        gr.Markdown(
+        gr.Markdown(header)
         with gr.Row():
             gr.Markdown(
                 """

@@ -107,18 +111,10 @@ def create_demo():
                 """
             )
         with gr.Row():
-            model_choice = gr.Radio(
-                choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2
-            )
-            pruning_ratio = gr.Text(
-                "20% Pruning Ratio for SDXL, FLUX", label="Pruning Ratio", scale=2
-            )
-            status_label = gr.HighlightedText(
-                label="Model Status", value=[("Model Not Initialized", "red")], scale=1
-            )
-            prune_btn = gr.Button(
-                "Initialize Original and Pruned Models", variant="primary", scale=1
-            )
+            model_choice = gr.Radio(choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2)
+            pruning_ratio = gr.Text("20% Pruning Ratio for SDXL, FLUX", label="Pruning Ratio", scale=2)
+            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
+            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
 
         with gr.Row():
             gr.Markdown(
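The two hunks above hoist the page header into a module-level `header` string and collapse each widget to a single line. Stripped of the app's real copy, the resulting Blocks skeleton looks roughly like this (placeholder strings are mine, the widget arguments come from the diff):

```python
import gradio as gr

HEADER = """
# Demo title (placeholder)
"""

def create_demo():
    with gr.Blocks() as demo:
        gr.Markdown(HEADER)  # static markdown, rendered once at the top
        with gr.Row():
            # scale= sets the relative width of each widget within the row
            model_choice = gr.Radio(choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2)
            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
    return demo
```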
@@ -182,4 +178,4 @@ def create_demo():
 
 if __name__ == "__main__":
     demo = create_demo()
-    demo.launch(share=True)
+    demo.launch(share=True)
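One note on the final hunk: `share=True` asks Gradio for a temporary public `*.gradio.live` tunnel in addition to the local server. Inside a Space the platform already serves the app, so the flag matters mainly when launching locally:

```python
if __name__ == "__main__":
    demo = create_demo()
    # share=True requests a public *.gradio.live link; useful for quick
    # sharing from a local machine, redundant when hosted on a Space.
    demo.launch(share=True)
```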