Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
0bb8ff5
1
Parent(s):
3a104e9
change device
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ from src.utils import (
|
|
12 |
)
|
13 |
from diffusers import StableDiffusionXLPipeline
|
14 |
|
|
|
15 |
|
16 |
def get_model_param_summary(model, verbose=False):
|
17 |
params_dict = dict()
|
@@ -29,7 +30,6 @@ def get_model_param_summary(model, verbose=False):
|
|
29 |
@dataclass
|
30 |
class GradioArgs:
|
31 |
ckpt: str = "./mask/ff.pt"
|
32 |
-
device: str = "cuda:0"
|
33 |
seed: list = None
|
34 |
prompt: str = None
|
35 |
mix_precision: str = "bf16"
|
@@ -95,9 +95,8 @@ def binary_mask_eval(args):
|
|
95 |
# load sdxl model
|
96 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
97 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
98 |
-
).to(
|
99 |
|
100 |
-
device = args.device
|
101 |
torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
|
102 |
mask_pipe, hookers = create_pipeline(
|
103 |
pipe,
|
@@ -132,7 +131,7 @@ def binary_mask_eval(args):
|
|
132 |
# reload the original model
|
133 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
134 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
135 |
-
).to(
|
136 |
|
137 |
# get model param summary
|
138 |
print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
|
@@ -143,9 +142,9 @@ def binary_mask_eval(args):
|
|
143 |
@spaces.GPU
|
144 |
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
|
145 |
# Run the model and return images directly
|
146 |
-
g_cpu = torch.Generator(
|
147 |
original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
|
148 |
-
g_cpu = torch.Generator(
|
149 |
ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
|
150 |
return original_image, ecodiff_image
|
151 |
|
|
|
12 |
)
|
13 |
from diffusers import StableDiffusionXLPipeline
|
14 |
|
15 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
|
17 |
def get_model_param_summary(model, verbose=False):
|
18 |
params_dict = dict()
|
|
|
30 |
@dataclass
|
31 |
class GradioArgs:
|
32 |
ckpt: str = "./mask/ff.pt"
|
|
|
33 |
seed: list = None
|
34 |
prompt: str = None
|
35 |
mix_precision: str = "bf16"
|
|
|
95 |
# load sdxl model
|
96 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
97 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
98 |
+
).to(device)
|
99 |
|
|
|
100 |
torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
|
101 |
mask_pipe, hookers = create_pipeline(
|
102 |
pipe,
|
|
|
131 |
# reload the original model
|
132 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
133 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
134 |
+
).to(device)
|
135 |
|
136 |
# get model param summary
|
137 |
print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
|
|
|
142 |
@spaces.GPU
|
143 |
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
|
144 |
# Run the model and return images directly
|
145 |
+
g_cpu = torch.Generator(device).manual_seed(seed)
|
146 |
original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
|
147 |
+
g_cpu = torch.Generator(device).manual_seed(seed)
|
148 |
ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
|
149 |
return original_image, ecodiff_image
|
150 |
|