zhangyang-0123 commited on
Commit
0bb8ff5
·
1 Parent(s): 3a104e9

change device

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -12,6 +12,7 @@ from src.utils import (
12
  )
13
  from diffusers import StableDiffusionXLPipeline
14
 
 
15
 
16
  def get_model_param_summary(model, verbose=False):
17
  params_dict = dict()
@@ -29,7 +30,6 @@ def get_model_param_summary(model, verbose=False):
29
  @dataclass
30
  class GradioArgs:
31
  ckpt: str = "./mask/ff.pt"
32
- device: str = "cuda:0"
33
  seed: list = None
34
  prompt: str = None
35
  mix_precision: str = "bf16"
@@ -95,9 +95,8 @@ def binary_mask_eval(args):
95
  # load sdxl model
96
  pipe = StableDiffusionXLPipeline.from_pretrained(
97
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
98
- ).to(args.device)
99
 
100
- device = args.device
101
  torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
102
  mask_pipe, hookers = create_pipeline(
103
  pipe,
@@ -132,7 +131,7 @@ def binary_mask_eval(args):
132
  # reload the original model
133
  pipe = StableDiffusionXLPipeline.from_pretrained(
134
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
135
- ).to(args.device)
136
 
137
  # get model param summary
138
  print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
@@ -143,9 +142,9 @@ def binary_mask_eval(args):
143
  @spaces.GPU
144
  def generate_images(prompt, seed, steps, pipe, pruned_pipe):
145
  # Run the model and return images directly
146
- g_cpu = torch.Generator("cuda:0").manual_seed(seed)
147
  original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
148
- g_cpu = torch.Generator("cuda:0").manual_seed(seed)
149
  ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
150
  return original_image, ecodiff_image
151
 
 
12
  )
13
  from diffusers import StableDiffusionXLPipeline
14
 
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  def get_model_param_summary(model, verbose=False):
18
  params_dict = dict()
 
30
  @dataclass
31
  class GradioArgs:
32
  ckpt: str = "./mask/ff.pt"
 
33
  seed: list = None
34
  prompt: str = None
35
  mix_precision: str = "bf16"
 
95
  # load sdxl model
96
  pipe = StableDiffusionXLPipeline.from_pretrained(
97
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
98
+ ).to(device)
99
 
 
100
  torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
101
  mask_pipe, hookers = create_pipeline(
102
  pipe,
 
131
  # reload the original model
132
  pipe = StableDiffusionXLPipeline.from_pretrained(
133
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
134
+ ).to(device)
135
 
136
  # get model param summary
137
  print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
 
142
  @spaces.GPU
143
  def generate_images(prompt, seed, steps, pipe, pruned_pipe):
144
  # Run the model and return images directly
145
+ g_cpu = torch.Generator(device).manual_seed(seed)
146
  original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
147
+ g_cpu = torch.Generator(device).manual_seed(seed)
148
  ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
149
  return original_image, ecodiff_image
150