Spaces:
Running
on
A100
Running
on
A100
fix
Browse files
server/pipelines/controlnetSDTurbo.py
CHANGED
@@ -160,20 +160,19 @@ class Pipeline:
|
|
160 |
def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
|
161 |
controlnet_canny = ControlNetModel.from_pretrained(
|
162 |
controlnet_model, torch_dtype=torch_dtype
|
163 |
-
)
|
164 |
-
|
165 |
self.pipes = {}
|
166 |
|
167 |
if args.safety_checker:
|
168 |
self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
169 |
-
base_model,
|
170 |
-
controlnet=controlnet_canny,
|
171 |
)
|
172 |
else:
|
173 |
self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
174 |
base_model,
|
175 |
controlnet=controlnet_canny,
|
176 |
safety_checker=None,
|
|
|
177 |
)
|
178 |
|
179 |
if args.taesd:
|
@@ -207,7 +206,7 @@ class Pipeline:
|
|
207 |
|
208 |
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
209 |
self.pipe.set_progress_bar_config(disable=True)
|
210 |
-
self.pipe.to(device=device, dtype=torch_dtype)
|
211 |
if device.type != "mps":
|
212 |
self.pipe.unet.to(memory_format=torch.channels_last)
|
213 |
|
|
|
160 |
def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
|
161 |
controlnet_canny = ControlNetModel.from_pretrained(
|
162 |
controlnet_model, torch_dtype=torch_dtype
|
163 |
+
)
|
|
|
164 |
self.pipes = {}
|
165 |
|
166 |
if args.safety_checker:
|
167 |
self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
168 |
+
base_model, controlnet=controlnet_canny, torch_dtype=torch_dtype
|
|
|
169 |
)
|
170 |
else:
|
171 |
self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
172 |
base_model,
|
173 |
controlnet=controlnet_canny,
|
174 |
safety_checker=None,
|
175 |
+
torch_dtype=torch_dtype,
|
176 |
)
|
177 |
|
178 |
if args.taesd:
|
|
|
206 |
|
207 |
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
208 |
self.pipe.set_progress_bar_config(disable=True)
|
209 |
+
self.pipe.to(device=device, dtype=torch_dtype)
|
210 |
if device.type != "mps":
|
211 |
self.pipe.unet.to(memory_format=torch.channels_last)
|
212 |
|