zhangyang-0123 commited on
Commit
3c5de4a
·
1 Parent(s): 0bb8ff5
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -95,13 +95,13 @@ def binary_mask_eval(args):
95
  # load sdxl model
96
  pipe = StableDiffusionXLPipeline.from_pretrained(
97
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
98
- ).to(device)
99
 
100
  torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
101
  mask_pipe, hookers = create_pipeline(
102
  pipe,
103
  args.model,
104
- device,
105
  torch_dtype,
106
  args.ckpt,
107
  binary=args.binary,
@@ -131,7 +131,7 @@ def binary_mask_eval(args):
131
  # reload the original model
132
  pipe = StableDiffusionXLPipeline.from_pretrained(
133
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
134
- ).to(device)
135
 
136
  # get model param summary
137
  print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
 
95
  # load sdxl model
96
  pipe = StableDiffusionXLPipeline.from_pretrained(
97
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
98
+ )#.to(device)
99
 
100
  torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
101
  mask_pipe, hookers = create_pipeline(
102
  pipe,
103
  args.model,
104
+ "cpu",
105
  torch_dtype,
106
  args.ckpt,
107
  binary=args.binary,
 
131
  # reload the original model
132
  pipe = StableDiffusionXLPipeline.from_pretrained(
133
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
134
+ )#.to(device)
135
 
136
  # get model param summary
137
  print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")