jokerbit committed on
Commit
a4d6c7b
·
verified ·
1 Parent(s): 4c08fae

Upload src/pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/pipeline.py +32 -26
src/pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- import gc
2
  import os
3
  from typing import TypeAlias
4
 
@@ -10,6 +10,7 @@ from pipelines.models import TextToImageRequest
10
  from torch import Generator
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel
 
13
 
14
 
15
  Pipeline: TypeAlias = FluxPipeline
@@ -28,37 +29,42 @@ TinyVAE = "madebyollin/taef1"
28
  TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
29
 
30
 
31
def load_pipeline() -> Pipeline:
    """Build, configure, and warm up the Flux.1-schnell text-to-image pipeline.

    Loads a pre-quantized (int8 weight-only) transformer snapshot from the
    local HF hub cache, a tiny VAE, and the base checkpoint — all in
    bfloat16 and strictly from local files — then moves the pipeline to
    CUDA, quantizes the VAE weights in place, and runs four short
    generations so CUDA kernels are compiled/cached before first use.

    Returns:
        The ready-to-serve FluxPipeline instance.
    """
    # NOTE(review): assumes the snapshot already exists in HF_HUB_CACHE —
    # confirm the deployment pre-populates it (local_files_only=True).
    snapshot_dir = os.path.join(HF_HUB_CACHE, "models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b/transformer")
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        snapshot_dir,
        use_safetensors=False,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    tiny_vae = AutoencoderTiny.from_pretrained(
        TinyVAE,
        revision=TinyVAE_REV,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        transformer=flux_transformer,
        vae=tiny_vae,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )

    # channels_last memory layout for the transformer; quantize the VAE
    # only after it is resident on the GPU.
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.set_progress_bar_config(disable=True)
    pipe.to("cuda")
    quantize_(pipe.vae, int8_weight_only())

    # Warm-up: four cheap 4-step generations trigger kernel compilation.
    for _ in range(4):
        pipe("cat", num_inference_steps=4)
    return pipe
61
 
 
62
  @torch.inference_mode()
63
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: torch.Generator) -> Image:
64
 
 
1
+ # onediff.compile
2
  import os
3
  from typing import TypeAlias
4
 
 
10
  from torch import Generator
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel
13
+ from functools import partial
14
 
15
 
16
  Pipeline: TypeAlias = FluxPipeline
 
29
  TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
30
 
31
 
32
# In-place int8 weight-only quantizer, pre-bound so call sites stay short.
my_quantize = partial(quantize_, apply_tensor_subclass=int8_weight_only())


# NOTE(review): the whole pipeline is constructed at import time and moved
# to CUDA here; only the warm-up remains in load_pipeline(). Confirm this
# import-time side effect is intended by the serving framework.
path = os.path.join(HF_HUB_CACHE, "models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b/transformer")

# Pre-quantized transformer snapshot, loaded strictly from the local cache.
transformer = FluxTransformer2DModel.from_pretrained(
    path,
    use_safetensors=False,
    local_files_only=True,
    torch_dtype=torch.bfloat16,
)

# Tiny VAE pinned to an exact revision for reproducibility.
vae = AutoencoderTiny.from_pretrained(
    TinyVAE,
    revision=TinyVAE_REV,
    local_files_only=True,
    torch_dtype=torch.bfloat16,
)

# Base checkpoint with the custom transformer and tiny VAE swapped in.
pipeline = FluxPipeline.from_pretrained(
    CHECKPOINT,
    revision=REVISION,
    transformer=transformer,
    vae=vae,
    local_files_only=True,
    torch_dtype=torch.bfloat16,
)

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.set_progress_bar_config(disable=True)
pipeline.to("cuda")
# Quantize the VAE after it is on the GPU.
my_quantize(pipeline.vae)
60
 
61
+
62
def load_pipeline() -> Pipeline:
    """Warm up the module-level pipeline and return it.

    The pipeline itself is built at import time (module level); this
    function only runs four cheap 4-step generations so CUDA kernels are
    compiled/cached before the first real request.

    Returns:
        The module-level FluxPipeline, ready to serve.
    """
    # Fix: the refactor that moved construction to module level dropped the
    # `-> Pipeline` return annotation the previous version had; restored.
    for _ in range(4):
        pipeline("cat", num_inference_steps=4)
    return pipeline
66
 
67
+
68
  @torch.inference_mode()
69
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: torch.Generator) -> Image:
70