English
John6666 commited on
Commit
e4e2447
·
verified ·
1 Parent(s): 316e239

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +8 -0
  2. handler.py +139 -0
  3. requirements.txt +15 -0
README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ license_name: flux-1-dev-non-commercial-license
4
+ license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.
5
+ language:
6
+ - en
7
+ inference: true
8
+ ---
handler.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/sayakpaul/diffusers-torchao
2
+
3
+ import os
4
+ from typing import Any, Dict
5
+
6
+ from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
7
+ from PIL import Image
8
+ import torch
9
+ from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight
10
+ from huggingface_hub import hf_hub_download
11
+ import gc
12
+
13
+ import subprocess
14
+ subprocess.run("pip list", shell=True)
15
+
16
+ IS_COMPILE = True
17
+ IS_TURBO = False
18
+ IS_4BIT = True
19
+
20
+ #if IS_COMPILE:
21
+ # import torch._dynamo
22
+ # torch._dynamo.config.suppress_errors = True
23
+
24
+ from huggingface_inference_toolkit.logging import logger
25
+
26
+ def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
27
+ quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "int8dq")
28
+ vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
29
+ pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
30
+ pipe.transformer.fuse_qkv_projections()
31
+ pipe.vae.fuse_qkv_projections()
32
+ pipe.to("cuda")
33
+ return pipe
34
+
35
+ def load_pipeline_compile(repo_id: str, dtype: torch.dtype) -> Any:
36
+ quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "int8dq")
37
+ vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
38
+ pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
39
+ pipe.transformer.fuse_qkv_projections()
40
+ pipe.vae.fuse_qkv_projections()
41
+ pipe.transformer.to(memory_format=torch.channels_last)
42
+ pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False)
43
+ pipe.vae.to(memory_format=torch.channels_last)
44
+ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False)
45
+ pipe.to("cuda")
46
+ return pipe
47
+
48
+ def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:
49
+ pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
50
+ pipe.transformer.fuse_qkv_projections()
51
+ pipe.vae.fuse_qkv_projections()
52
+ pipe.transformer.to(memory_format=torch.channels_last)
53
+ pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
54
+ pipe.vae.to(memory_format=torch.channels_last)
55
+ pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
56
+ pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
57
+ pipe.vae = autoquant(pipe.vae, error_on_unseen=False)
58
+ pipe.to("cuda")
59
+ return pipe
60
+
61
+ def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
62
+ pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
63
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
64
+ pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
65
+ pipe.fuse_lora()
66
+ pipe.unload_lora_weights()
67
+ pipe.transformer.fuse_qkv_projections()
68
+ pipe.vae.fuse_qkv_projections()
69
+ weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
70
+ quantize_(pipe.transformer, weight, device="cuda")
71
+ quantize_(pipe.vae, weight, device="cuda")
72
+ quantize_(pipe.text_encoder_2, weight, device="cuda")
73
+ pipe.to("cuda")
74
+ return pipe
75
+
76
+ def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
77
+ pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
78
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
79
+ pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
80
+ pipe.fuse_lora()
81
+ pipe.unload_lora_weights()
82
+ pipe.transformer.fuse_qkv_projections()
83
+ pipe.vae.fuse_qkv_projections()
84
+ weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
85
+ quantize_(pipe.transformer, weight, device="cuda")
86
+ quantize_(pipe.vae, weight, device="cuda")
87
+ quantize_(pipe.text_encoder_2, weight, device="cuda")
88
+ pipe.transformer.to(memory_format=torch.channels_last)
89
+ pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False)
90
+ pipe.vae.to(memory_format=torch.channels_last)
91
+ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False)
92
+ pipe.to("cuda")
93
+ return pipe
94
+
95
+ class EndpointHandler:
96
+ def __init__(self, path=""):
97
+ repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
98
+ dtype = torch.bfloat16
99
+ #dtype = torch.float16 # for older nVidia GPUs
100
+ if IS_COMPILE: load_pipeline_compile(repo_id, dtype)
101
+ else: self.pipeline = load_pipeline_stable(repo_id, dtype)
102
+ gc.collect()
103
+ torch.cuda.empty_cache()
104
+
105
+ def __call__(self, data: Dict[str, Any]) -> Image.Image:
106
+ logger.info(f"Received incoming request with {data=}")
107
+
108
+ if "inputs" in data and isinstance(data["inputs"], str):
109
+ prompt = data.pop("inputs")
110
+ elif "prompt" in data and isinstance(data["prompt"], str):
111
+ prompt = data.pop("prompt")
112
+ else:
113
+ raise ValueError(
114
+ "Provided input body must contain either the key `inputs` or `prompt` with the"
115
+ " prompt to use for the image generation, and it needs to be a non-empty string."
116
+ )
117
+
118
+ parameters = data.pop("parameters", {})
119
+
120
+ num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
121
+ width = parameters.get("width", 1024)
122
+ height = parameters.get("height", 1024)
123
+ guidance_scale = parameters.get("guidance_scale", 3.5)
124
+
125
+ # seed generator (seed cannot be provided as is but via a generator)
126
+ seed = parameters.get("seed", 0)
127
+ generator = torch.manual_seed(seed)
128
+
129
+ return self.pipeline( # type: ignore
130
+ prompt,
131
+ height=height,
132
+ width=width,
133
+ guidance_scale=guidance_scale,
134
+ num_inference_steps=num_inference_steps,
135
+ generator=generator,
136
+ output_type="pil",
137
+ ).images[0]
138
+
139
+
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+ torch==2.6.0+cu121
3
+ torchvision
4
+ torchaudio
5
+ huggingface_hub
6
+ torchao==0.9.0
7
+ diffusers==0.32.2
8
+ peft
9
+ transformers
10
+ numpy
11
+ scipy
12
+ Pillow
13
+ sentencepiece
14
+ protobuf
15
+ triton