Update src/pipeline.py

src/pipeline.py (changed, +3 -1):

@@ -47,7 +47,9 @@ def load_pipeline() -> Pipeline:
         text_encoder_2=text_encoder_2,
         torch_dtype=dtype,
     ).to(device)
-
+    #torch.compile(model: None = None, *, fullgraph: bool = False, dynamic: Optional[bool] = None, backend: Union[str, Callable] = 'inductor', mode: Optional[str] = None, options: Optional[Dict[str, Union[str, int, bool]]] = None, disable: bool = False) → Callable[[Callable[[_InputT], _RetT]], Callable[[_InputT], _RetT]]
+
+    pipeline.transformer = torch.compile(pipeline.transformer, fullgraph=True, mode="max-autotune")
     #quantize_(pipeline.vae, int8_weight_only())
     for _ in range(3):
         pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)