zhangyang-0123 commited on
Commit
48cfefd
·
1 Parent(s): a3181aa

date banner

Browse files
Files changed (2) hide show
  1. README.md +0 -2
  2. app.py +31 -35
README.md CHANGED
@@ -9,8 +9,6 @@ app_file: app.py
9
  pinned: true
10
  python_version: 3.10
11
  short_description: Diffusion Model Compression
12
- hf_oauth: true
13
- hf_oauth_expiration_minutes: 43200
14
  ---
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
9
  pinned: true
10
  python_version: 3.10
11
  short_description: Diffusion Model Compression
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -38,19 +38,15 @@ def binary_mask_eval(args, model):
38
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
39
  ).to("cpu")
40
  pruned_pipe.unet = torch.load(
41
- hf_hub_download(
42
- "zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"
43
- ),
44
  map_location="cpu",
45
  )
46
  elif model == "flux":
47
- pruned_pipe = FluxPipeline.from_pretrained(
48
- "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
49
- ).to("cpu")
50
  pruned_pipe.transformer = torch.load(
51
- hf_hub_download(
52
- "zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"
53
- ),
54
  map_location="cpu",
55
  )
56
 
@@ -60,9 +56,7 @@ def binary_mask_eval(args, model):
60
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
61
  ).to("cpu")
62
  elif model == "flux":
63
- pipe = FluxPipeline.from_pretrained(
64
- "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
65
- ).to("cpu")
66
 
67
  print("prune complete")
68
  return pipe, pruned_pipe
@@ -74,13 +68,9 @@ def generate_images(prompt, seed, steps, pipe, pruned_pipe):
74
  pruned_pipe.to("cuda")
75
  # Run the model and return images directly
76
  g_cpu = torch.Generator("cuda").manual_seed(seed)
77
- original_image = pipe(
78
- prompt=prompt, generator=g_cpu, num_inference_steps=steps
79
- ).images[0]
80
  g_cpu = torch.Generator("cuda").manual_seed(seed)
81
- ecodiff_image = pruned_pipe(
82
- prompt=prompt, generator=g_cpu, num_inference_steps=steps
83
- ).images[0]
84
  return original_image, ecodiff_image
85
 
86
 
@@ -91,15 +81,29 @@ def on_prune_click(prompt, seed, steps, model):
91
 
92
 
93
  def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
94
- original_image, ecodiff_image = generate_images(
95
- prompt, seed, steps, pipe, pruned_pipe
96
- )
97
  return original_image, ecodiff_image
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def create_demo():
101
  with gr.Blocks() as demo:
102
- gr.Markdown("# Text-to-Image Generation with EcoDiff Pruned Model")
103
  with gr.Row():
104
  gr.Markdown(
105
  """
@@ -107,18 +111,10 @@ def create_demo():
107
  """
108
  )
109
  with gr.Row():
110
- model_choice = gr.Radio(
111
- choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2
112
- )
113
- pruning_ratio = gr.Text(
114
- "20% Pruning Ratio for SDXL, FLUX", label="Pruning Ratio", scale=2
115
- )
116
- status_label = gr.HighlightedText(
117
- label="Model Status", value=[("Model Not Initialized", "red")], scale=1
118
- )
119
- prune_btn = gr.Button(
120
- "Initialize Original and Pruned Models", variant="primary", scale=1
121
- )
122
 
123
  with gr.Row():
124
  gr.Markdown(
@@ -182,4 +178,4 @@ def create_demo():
182
 
183
  if __name__ == "__main__":
184
  demo = create_demo()
185
- demo.launch(share=True)
 
38
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
39
  ).to("cpu")
40
  pruned_pipe.unet = torch.load(
41
+ hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
 
 
42
  map_location="cpu",
43
  )
44
  elif model == "flux":
45
+ pruned_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(
46
+ "cpu"
47
+ )
48
  pruned_pipe.transformer = torch.load(
49
+ hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"),
 
 
50
  map_location="cpu",
51
  )
52
 
 
56
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
57
  ).to("cpu")
58
  elif model == "flux":
59
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cpu")
 
 
60
 
61
  print("prune complete")
62
  return pipe, pruned_pipe
 
68
  pruned_pipe.to("cuda")
69
  # Run the model and return images directly
70
  g_cpu = torch.Generator("cuda").manual_seed(seed)
71
+ original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
 
 
72
  g_cpu = torch.Generator("cuda").manual_seed(seed)
73
+ ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
 
 
74
  return original_image, ecodiff_image
75
 
76
 
 
81
 
82
 
83
  def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
84
+ original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
 
 
85
  return original_image, ecodiff_image
86
 
87
 
88
+ header = """
89
+ # 🌱 Text-to-Image Generation with EcoDiff Pruned Models
90
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
91
+ <a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
92
+ <a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
93
+ <a href="https://github.com/YaNgZhAnG-V5/EcoDiff"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
94
+ </div>
95
+
96
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
97
+ For **⚡faster⚡** DEMO on one model only, please visit
98
+ <a href="https://huggingface.co/spaces/zhangyang-0123/EcoDiff-SD-XL"><img alt="Static Badge" src="https://img.shields.io/badge/SDXL-fedcba.svg"></a>
99
+ <a href="https://huggingface.co/spaces/zhangyang-0123/EcoDiff-FLUX-Schnell"><img alt="Static Badge" src="https://img.shields.io/badge/FLUX-fgdfba"></a>
100
+ </div>
101
+ """
102
+
103
+
104
  def create_demo():
105
  with gr.Blocks() as demo:
106
+ gr.Markdown(header)
107
  with gr.Row():
108
  gr.Markdown(
109
  """
 
111
  """
112
  )
113
  with gr.Row():
114
+ model_choice = gr.Radio(choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2)
115
+ pruning_ratio = gr.Text("20% Pruning Ratio for SDXL, FLUX", label="Pruning Ratio", scale=2)
116
+ status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
117
+ prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
 
 
 
 
 
 
 
 
118
 
119
  with gr.Row():
120
  gr.Markdown(
 
178
 
179
  if __name__ == "__main__":
180
  demo = create_demo()
181
+ demo.launch(share=True)