import gradio as gr
from dataclasses import dataclass
import spaces
import torch
from tqdm import tqdm

from src.utils import (
    create_pipeline,
    calculate_mask_sparsity,
    ffn_linear_layer_pruning,
    linear_layer_pruning,
)
from diffusers import StableDiffusionXLPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"


def get_model_param_summary(model, verbose=False):
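    """Return a dict mapping each parameter name to its element count, plus an "overall" total."""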
    params_dict = dict()
    overall_params = 0
    for name, params in model.named_parameters():
        num_params = params.numel()
        overall_params += num_params
        if verbose:
            size_mib = num_params * params.element_size() / 2**20
            print(f"Memory requirement for {name}: {size_mib:.2f} MiB")
        params_dict.update({name: num_params})
    params_dict.update({"overall": overall_params})
    return params_dict


@dataclass
class GradioArgs:
    ckpt: str = "./mask/ff.pt"
    seed: list = None
    prompt: str = None
    mix_precision: str = "bf16"
    num_intervention_steps: int = 50
    model: str = "sdxl"
    binary: bool = False
    masking: str = "binary"
    scope: str = "global"
    ratio: list = None
    width: int = None
    height: int = None
    epsilon: float = 0.0
    lambda_threshold: float = 0.001

    def __post_init__(self):
        if self.seed is None:
            self.seed = [44]
        if self.ratio is None:
            self.ratio = [0.68, 0.88]


def prune_model(pipe, hookers):
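    """Replace hooked attention and FFN submodules with pruned versions based on each hooker's lambda masks."""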
    # remove parameters in attention blocks
    cross_attn_hooker = hookers[0]
    for name in tqdm(cross_attn_hooker.hook_dict.keys(), desc="Pruning attention layers"):
        if getattr(pipe, "unet", None):
            module = pipe.unet.get_submodule(name)
        else:
            module = pipe.transformer.get_submodule(name)
        lamb = cross_attn_hooker.lambs[cross_attn_hooker.lambs_module_names.index(name)]
        assert module.heads == lamb.shape[0]
        module = linear_layer_pruning(module, lamb)

        parent_module_name, child_name = name.rsplit(".", 1)
        if getattr(pipe, "unet", None):
            parent_module = pipe.unet.get_submodule(parent_module_name)
        else:
            parent_module = pipe.transformer.get_submodule(parent_module_name)
        setattr(parent_module, child_name, module)

    # remove parameters in ffn blocks
    ffn_hook = hookers[1]
    for name in tqdm(ffn_hook.hook_dict.keys(), desc="Pruning FFN linear layers"):
        if getattr(pipe, "unet", None):
            module = pipe.unet.get_submodule(name)
        else:
            module = pipe.transformer.get_submodule(name)
        lamb = ffn_hook.lambs[ffn_hook.lambs_module_names.index(name)]
        module = ffn_linear_layer_pruning(module, lamb)

        parent_module_name, child_name = name.rsplit(".", 1)
        if getattr(pipe, "unet", None):
            parent_module = pipe.unet.get_submodule(parent_module_name)
        else:
            parent_module = pipe.transformer.get_submodule(parent_module_name)
        setattr(parent_module, child_name, module)

    cross_attn_hooker.clear_hooks()
    ffn_hook.clear_hooks()
    return pipe


def binary_mask_eval(args):
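    """Load SDXL, prune it using the masks loaded from args.ckpt, and return (original_pipe, pruned_pipe)."""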
    # load sdxl model
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
    ).to("cpu")

    torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
    mask_pipe, hookers = create_pipeline(
        pipe,
        args.model,
        "cpu",
        torch_dtype,
        args.ckpt,
        binary=args.binary,
        lambda_threshold=args.lambda_threshold,
        epsilon=args.epsilon,
        masking=args.masking,
        return_hooker=True,
        scope=args.scope,
        ratio=args.ratio,
    )

    # # Print mask sparsity info
    # threshold = None if args.binary else args.lambda_threshold
    # threshold = None if args.scope is not None else threshold
    # name = ["ff", "attn"]
    # for n, hooker in zip(name, hookers):
    #     total_num_heads, num_activate_heads, mask_sparsity = calculate_mask_sparsity(hooker, threshold)
    #     print(f"model: {args.model}, {n} masking: {args.masking}")
    #     print(
    #         f"total num heads: {total_num_heads},"
    #         + f"num activate heads: {num_activate_heads}, mask sparsity: {mask_sparsity}"
    #     )

    # Prune the model
    pruned_pipe = prune_model(mask_pipe, hookers)

    # reload the original model
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
    ).to("cpu")

    # get model param summary
    print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
    print(f"pruned model param: {get_model_param_summary(pruned_pipe.unet)['overall']}")
    print("prune complete")
    return pipe, pruned_pipe


@spaces.GPU
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
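    """Generate one image with the original pipeline and one with the pruned pipeline, using the same seed."""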
    pipe.to("cuda")
    pruned_pipe.to("cuda")
    # Run both pipelines with the same seed and return the images directly
    generator = torch.Generator("cuda").manual_seed(seed)
    original_image = pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    generator = torch.Generator("cuda").manual_seed(seed)
    ecodiff_image = pruned_pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    return original_image, ecodiff_image


def on_prune_click(prompt, seed, steps):
    args = GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
    pipe, pruned_pipe = binary_mask_eval(args)
    return pipe, pruned_pipe, [("Model Initialized", "green")]


def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
    return original_image, ecodiff_image


def create_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Text-to-Image Generation with EcoDiff Pruned Model")
        with gr.Row():
            gr.Markdown(
                """
                # 🚧 Under Construction 🚧
                This demo is currently being developed and may not be fully functional. More models and pruning ratios will be supported soon.
                The current pruned model checkpoint is not optimal and does not provide the best performance.
                
                **Note: Please initialize the models before generating images. Initialization may take up to 5 minutes because it runs on the CPU.**
                """
            )
        with gr.Row():
            model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
            pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
        with gr.Row():
            gr.Markdown(
                """
                **Generate images with the original model and the pruned model. Generation may take up to 1 minute due to dynamic GPU allocation.**
                """
            )
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
            steps = gr.Slider(label="Number of Steps", minimum=1, maximum=100, value=50, step=1, scale=1)
            generate_btn = gr.Button("Generate Images")
        gr.Examples(
            examples=[
                "A clock tower floating in a sea of clouds",
                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                "An astronaut riding a green horse",
                "A delicious ceviche cheesecake slice",
            ],
            inputs=[prompt],
        )
        with gr.Row():
            original_output = gr.Image(label="Original Output")
            ecodiff_output = gr.Image(label="EcoDiff Output")

        pipe_state = gr.State(None)
        pruned_pipe_state = gr.State(None)
        prompt.submit(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
        prune_btn.click(
            fn=on_prune_click,
            inputs=[prompt, seed, steps],
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )
        generate_btn.click(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()