Update src/pipeline.py
Browse files- src/pipeline.py +12 -23
src/pipeline.py
CHANGED
@@ -1314,35 +1314,24 @@ class StableDiffusionXLPipeline(
|
|
1314 |
|
1315 |
return StableDiffusionXLPipelineOutput(images=image)
|
1316 |
|
1317 |
-
|
1318 |
-
|
1319 |
-
|
1320 |
-
|
1321 |
-
|
1322 |
-
|
1323 |
-
|
|
|
|
|
1324 |
pipeline.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16).to('cuda')
|
1325 |
-
|
1326 |
-
# xformers and Triton are suggested for achieving best performance.
|
1327 |
-
try:
|
1328 |
-
import xformers
|
1329 |
-
config.enable_xformers = True
|
1330 |
-
except ImportError:
|
1331 |
-
print('xformers not installed, skip')
|
1332 |
-
try:
|
1333 |
-
import triton
|
1334 |
-
config.enable_triton = True
|
1335 |
-
except ImportError:
|
1336 |
-
print('Triton not installed, skip')
|
1337 |
-
config.enable_cuda_graph = True
|
1338 |
-
|
1339 |
-
pipeline = compile(pipeline, config)
|
1340 |
for _ in range(2):
|
1341 |
pipeline(prompt="", num_inference_steps=10)
|
1342 |
|
1343 |
return pipeline
|
1344 |
|
1345 |
|
|
|
1346 |
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
|
1347 |
generator = Generator(pipeline.device).manual_seed(request.seed) if request.seed else None
|
1348 |
|
@@ -1352,5 +1341,5 @@ def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> I
|
|
1352 |
width=request.width,
|
1353 |
height=request.height,
|
1354 |
generator=generator,
|
1355 |
-
num_inference_steps=
|
1356 |
).images[0]
|
|
|
1314 |
|
1315 |
return StableDiffusionXLPipelineOutput(images=image)
|
1316 |
|
1317 |
+
from onediffx import compile_pipe
|
1318 |
+
|
1319 |
+
def load_pipeline(pipeline=None) -> StableDiffusionXLPipeline:
|
1320 |
+
if not pipeline:
|
1321 |
+
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
1322 |
+
"./models/newdream-sdxl-20",
|
1323 |
+
torch_dtype=torch.float16,
|
1324 |
+
local_files_only=True,
|
1325 |
+
).to("cuda")
|
1326 |
pipeline.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16).to('cuda')
|
1327 |
+
pipeline = compile_pipe(pipeline)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1328 |
for _ in range(2):
|
1329 |
pipeline(prompt="", num_inference_steps=10)
|
1330 |
|
1331 |
return pipeline
|
1332 |
|
1333 |
|
1334 |
+
|
1335 |
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
|
1336 |
generator = Generator(pipeline.device).manual_seed(request.seed) if request.seed else None
|
1337 |
|
|
|
1341 |
width=request.width,
|
1342 |
height=request.height,
|
1343 |
generator=generator,
|
1344 |
+
num_inference_steps=10,
|
1345 |
).images[0]
|