Disty0 committed
Commit b3e4533 · verified · 1 Parent(s): 1d3c711

Update README.md

Files changed (1):
  1. README.md +57 -0
README.md CHANGED
@@ -9,3 +9,60 @@ tags:
  - image-generation
  - flux
  ---
+
+ Quantized to INT8 using Optimum Quanto. Install the required packages, then load and run the quantized checkpoint with the snippet below:
+
+ ```shell
+ pip install diffusers optimum-quanto
+ ```
+
+ ```python
+ import json
+ import torch
+ import diffusers
+ import transformers
+ from optimum.quanto import requantize
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+
+
+ def load_quanto_transformer(repo_path):
+     # Read the quantization map saved alongside the quantized transformer weights
+     with open(hf_hub_download(repo_path, "transformer/quantization_map.json"), "r") as f:
+         quantization_map = json.load(f)
+     # Instantiate the transformer on the meta device so no memory is allocated yet
+     with torch.device("meta"):
+         transformer = diffusers.FluxTransformer2DModel.from_config(hf_hub_download(repo_path, "transformer/config.json")).to(torch.bfloat16)
+     state_dict = load_file(hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors"))
+     # Requantize the meta-device model from the saved state dict and quantization map
+     requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
+     return transformer
+
+
+ def load_quanto_text_encoder_2(repo_path):
+     # Read the quantization map and T5 config saved alongside the quantized text encoder
+     with open(hf_hub_download(repo_path, "text_encoder_2/quantization_map.json"), "r") as f:
+         quantization_map = json.load(f)
+     with open(hf_hub_download(repo_path, "text_encoder_2/config.json")) as f:
+         t5_config = transformers.T5Config(**json.load(f))
+     with torch.device("meta"):
+         text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)
+     state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
+     requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
+     return text_encoder_2
+
+
+ # Load the pipeline without the transformer and text_encoder_2, then attach the requantized modules
+ pipe = diffusers.AutoPipelineForText2Image.from_pretrained("Disty0/FLUX.1-dev-qint8", transformer=None, text_encoder_2=None, torch_dtype=torch.bfloat16)
+ pipe.transformer = load_quanto_transformer("Disty0/FLUX.1-dev-qint8")
+ pipe.text_encoder_2 = load_quanto_text_encoder_2("Disty0/FLUX.1-dev-qint8")
+ pipe = pipe.to("cuda", dtype=torch.bfloat16)
+
+
+ prompt = "A cat holding a sign that says hello world"
+ image = pipe(
+     prompt,
+     height=1024,
+     width=1024,
+     guidance_scale=3.5,
+     num_inference_steps=50,
+     max_sequence_length=512,
+     generator=torch.Generator("cpu").manual_seed(0)
+ ).images[0]
+ image.save("flux-dev.png")
+ ```
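
For reference, a checkpoint with this layout (quantized `.safetensors` weights plus a `quantization_map.json` per component) can be produced with Optimum Quanto's `quantize`, `freeze`, and `quantization_map` APIs. The snippet below is a minimal sketch of that export step for the transformer, assuming the base `black-forest-labs/FLUX.1-dev` weights as input and an illustrative `transformer/` output folder; it is not necessarily the exact script used to create this repository.

```python
import json
import os
import torch
import diffusers
from optimum.quanto import quantize, freeze, quantization_map, qint8
from safetensors.torch import save_file

# Load the original bfloat16 transformer from the base FLUX.1-dev repository (assumed input)
transformer = diffusers.FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Quantize the weights to INT8, then freeze to materialize the quantized tensors
quantize(transformer, weights=qint8)
freeze(transformer)

# Save the config, quantized weights, and the quantization map that requantize() expects
os.makedirs("transformer", exist_ok=True)  # illustrative output folder mirroring this repo's layout
transformer.save_config("transformer")
save_file(transformer.state_dict(), "transformer/diffusion_pytorch_model.safetensors")
with open("transformer/quantization_map.json", "w") as f:
    json.dump(quantization_map(transformer), f)
```

The same steps apply to `text_encoder_2` (the T5 encoder), whose quantized weights and quantization map are loaded back with `requantize` in the usage example above.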