zhangyang-0123 committed on
Commit
7ad3113
·
1 Parent(s): 3c5de4a
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -50,7 +50,7 @@ class GradioArgs:
50
  if self.ratio is None:
51
  self.ratio = [0.68, 0.88]
52
 
53
-
54
  def prune_model(pipe, hookers):
55
  # remove parameters in attention blocks
56
  cross_attn_hooker = hookers[0]
@@ -91,17 +91,18 @@ def prune_model(pipe, hookers):
91
  ffn_hook.clear_hooks()
92
  return pipe
93
 
 
94
  def binary_mask_eval(args):
95
  # load sdxl model
96
  pipe = StableDiffusionXLPipeline.from_pretrained(
97
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
98
- )#.to(device)
99
 
100
  torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
101
  mask_pipe, hookers = create_pipeline(
102
  pipe,
103
  args.model,
104
- "cpu",
105
  torch_dtype,
106
  args.ckpt,
107
  binary=args.binary,
@@ -131,7 +132,7 @@ def binary_mask_eval(args):
131
  # reload the original model
132
  pipe = StableDiffusionXLPipeline.from_pretrained(
133
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
134
- )#.to(device)
135
 
136
  # get model param summary
137
  print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
 
50
  if self.ratio is None:
51
  self.ratio = [0.68, 0.88]
52
 
53
+ @spaces.GPU
54
  def prune_model(pipe, hookers):
55
  # remove parameters in attention blocks
56
  cross_attn_hooker = hookers[0]
 
91
  ffn_hook.clear_hooks()
92
  return pipe
93
 
94
+ @spaces.GPU
95
  def binary_mask_eval(args):
96
  # load sdxl model
97
  pipe = StableDiffusionXLPipeline.from_pretrained(
98
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
99
+ ).to(device)
100
 
101
  torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
102
  mask_pipe, hookers = create_pipeline(
103
  pipe,
104
  args.model,
105
+ device,
106
  torch_dtype,
107
  args.ckpt,
108
  binary=args.binary,
 
132
  # reload the original model
133
  pipe = StableDiffusionXLPipeline.from_pretrained(
134
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
135
+ ).to(device)
136
 
137
  # get model param summary
138
  print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")