Yuanshi committed on
Commit 029e8ff · verified · 1 Parent(s): 049826f
Files changed (4)
  1. app.py +94 -59
  2. attention_processor.py +253 -0
  3. pipeline_flux.py +789 -0
  4. transformer_flux.py +560 -0
app.py CHANGED
@@ -2,37 +2,65 @@ import gradio as gr
2
  import numpy as np
3
  import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
 
7
  import torch
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
  pipe = pipe.to(device)
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
 
22
 
23
 
24
  # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
  prompt,
27
- negative_prompt,
28
  seed,
29
  randomize_seed,
30
  width,
31
  height,
32
- guidance_scale,
33
  num_inference_steps,
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
 
@@ -40,17 +68,22 @@ def infer(
40
 
41
  image = pipe(
42
  prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
  num_inference_steps=num_inference_steps,
46
  width=width,
47
  height=height,
48
  generator=generator,
49
  ).images[0]
50
 
51
  return image, seed
52
 
53
 
54
  examples = [
55
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
  "An astronaut riding a green horse",
@@ -58,17 +91,25 @@ examples = [
58
  ]
59
 
60
  css = """
61
- #col-container {
62
  margin: 0 auto;
63
- max-width: 640px;
64
  }
65
  """
66
 
67
  with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
  prompt = gr.Text(
73
  label="Prompt",
74
  show_label=False,
@@ -77,35 +118,27 @@ with gr.Blocks(css=css) as demo:
77
  container=False,
78
  )
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
 
81
 
82
- result = gr.Image(label="Result", show_label=False)
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
  )
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
  with gr.Row():
103
  width = gr.Slider(
104
  label="Width",
105
  minimum=256,
106
  maximum=MAX_IMAGE_SIZE,
107
  step=32,
108
- value=1024, # Replace with defaults that work for your model
109
  )
110
 
111
  height = gr.Slider(
@@ -113,38 +146,40 @@ with gr.Blocks(css=css) as demo:
113
  minimum=256,
114
  maximum=MAX_IMAGE_SIZE,
115
  step=32,
116
- value=1024, # Replace with defaults that work for your model
117
  )
118
 
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
 
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
  gr.on(
138
  triggers=[run_button.click, prompt.submit],
139
  fn=infer,
140
  inputs=[
141
  prompt,
142
- negative_prompt,
143
  seed,
144
  randomize_seed,
145
  width,
146
  height,
147
- guidance_scale,
148
  num_inference_steps,
149
  ],
150
  outputs=[result, seed],
 
2
  import numpy as np
3
  import random
4
 
5
+ import spaces
6
+ from pipeline_flux import FluxPipeline
7
+ from transformer_flux import FluxTransformer2DModel
8
  import torch
9
+ from patch_conv import convert_model
10
 
11
+ flux_model = "schnell"
12
+ bfl_repo = f"black-forest-labs/FLUX.1-{flux_model}"
13
 
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ dtype = torch.bfloat16
16
+
17
+ transformer = FluxTransformer2DModel.from_pretrained(
18
+ bfl_repo, subfolder="transformer", torch_dtype=dtype
19
+ )
20
+ pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)
21
+ pipe.transformer = transformer
22
+ pipe.scheduler.config.use_dynamic_shifting = False
23
+ pipe.scheduler.config.time_shift = 10
24
+ # pipe.enable_model_cpu_offload()
25
  pipe = pipe.to(device)
26
 
27
+ pipe.load_lora_weights(
28
+ "Huage001/URAE",
29
+ weight_name="urae_2k_adapter.safetensors",
30
+ adapter_name="2k",
31
+ )
32
+ pipe.load_lora_weights(
33
+ "Huage001/URAE",
34
+ weight_name="urae_4k_adapter_lora_conversion_dev.safetensors",
35
+ adapter_name="4k_dev",
36
+ )
37
+ pipe.load_lora_weights(
38
+ "Huage001/URAE",
39
+ weight_name="urae_4k_adapter_lora_conversion_schnell.safetensors",
40
+ adapter_name="4k_schnell",
41
+ )
42
  MAX_SEED = np.iinfo(np.int32).max
43
+ MAX_IMAGE_SIZE = 4096
44
+ USE_ZERO_GPU = True
45
 
46
 
47
  # @spaces.GPU #[uncomment to use ZeroGPU]
48
  def infer(
49
  prompt,
50
+ model,
51
  seed,
52
  randomize_seed,
53
  width,
54
  height,
 
55
  num_inference_steps,
56
  progress=gr.Progress(track_tqdm=True),
57
  ):
58
+ print("Using model:", model)
59
+ if model == "2k":
60
+ pipe.set_adapters("2k")
61
+ elif model == "4k":
62
+ pipe.set_adapters(f"4k_{flux_model}")
63
+
64
  if randomize_seed:
65
  seed = random.randint(0, MAX_SEED)
66
 
 
68
 
69
  image = pipe(
70
  prompt=prompt,
71
+ guidance_scale=0,
 
72
  num_inference_steps=num_inference_steps,
73
  width=width,
74
  height=height,
75
+ max_sequence_length=256,
76
+ ntk_factor=10,
77
+ proportional_attention=True,
78
  generator=generator,
79
  ).images[0]
80
 
81
  return image, seed
82
 
83
 
84
+ if USE_ZERO_GPU:
85
+ infer = spaces.GPU(infer)
86
+
87
  examples = [
88
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
89
  "An astronaut riding a green horse",
 
91
  ]
92
 
93
  css = """
94
+ #maincontainer {
95
+ display: flex;
96
+ }
97
+
98
+ #col1 {
99
+ margin: 0 auto;
100
+ max-width: 50%;
101
+ }
102
+ #col2 {
103
  margin: 0 auto;
104
+ /* max-width: 40px; */
105
  }
106
  """
107
 
108
  with gr.Blocks(css=css) as demo:
109
+ gr.Markdown("# URAE: ")
110
+ with gr.Row(elem_id="maincontainer"):
111
+ with gr.Column(elem_id="col1"):
112
+ gr.Markdown("### Prompt:")
113
  prompt = gr.Text(
114
  label="Prompt",
115
  show_label=False,
 
118
  container=False,
119
  )
120
 
121
+ gr.Examples(examples=examples, inputs=[prompt])
122
+ run_button = gr.Button("Generate", scale=1, variant="primary")
123
 
124
+ gr.Markdown("### Setting:")
125
 
126
+ model = gr.Radio(
127
+ label="Model",
128
+ choices=[
129
+ ("2K model", "2k"),
130
+ ("4K model (beta)", "4k"),
131
+ ],
132
+ value="2k",
133
  )
134
 
 
 
135
  with gr.Row():
136
  width = gr.Slider(
137
  label="Width",
138
  minimum=256,
139
  maximum=MAX_IMAGE_SIZE,
140
  step=32,
141
+ value=2048, # Replace with defaults that work for your model
142
  )
143
 
144
  height = gr.Slider(
 
146
  minimum=256,
147
  maximum=MAX_IMAGE_SIZE,
148
  step=32,
149
+ value=2048, # Replace with defaults that work for your model
150
  )
151
 
152
+ seed = gr.Slider(
153
+ label="Seed",
154
+ minimum=0,
155
+ maximum=MAX_SEED,
156
+ step=1,
157
+ value=0,
158
+ )
 
159
 
160
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
161
+
162
+ num_inference_steps = gr.Slider(
163
+ label="Number of inference steps",
164
+ minimum=1,
165
+ maximum=50,
166
+ step=1,
167
+ value=4, # Replace with defaults that work for your model
168
+ )
169
+
170
+ with gr.Column(elem_id="col2"):
171
+ result = gr.Image(label="Result", show_label=False)
172
 
 
173
  gr.on(
174
  triggers=[run_button.click, prompt.submit],
175
  fn=infer,
176
  inputs=[
177
  prompt,
178
+ model,
179
  seed,
180
  randomize_seed,
181
  width,
182
  height,
 
183
  num_inference_steps,
184
  ],
185
  outputs=[result, seed],
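
Note: the rewritten app.py above wires the custom FluxPipeline and FluxTransformer2DModel together and switches between the URAE LoRA adapters at inference time. Below is a minimal standalone sketch of the same setup (2K adapter only, outside Gradio). It assumes a CUDA GPU with enough memory for FLUX.1-schnell in bfloat16; repo names, file names, and call arguments are taken from the diff above.

```python
# Minimal sketch mirroring the setup in app.py above (2K adapter only).
import torch

from pipeline_flux import FluxPipeline
from transformer_flux import FluxTransformer2DModel

bfl_repo = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_pretrained(
    bfl_repo, subfolder="transformer", torch_dtype=dtype
)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.scheduler.config.use_dynamic_shifting = False
pipe.scheduler.config.time_shift = 10
pipe = pipe.to("cuda")

# Load and activate the 2K URAE adapter (names as in the diff above).
pipe.load_lora_weights(
    "Huage001/URAE", weight_name="urae_2k_adapter.safetensors", adapter_name="2k"
)
pipe.set_adapters("2k")

image = pipe(
    prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    guidance_scale=0,             # schnell is guidance-distilled
    num_inference_steps=4,
    width=2048,
    height=2048,
    max_sequence_length=256,
    ntk_factor=10,                # RoPE extension for resolutions beyond training
    proportional_attention=True,  # length-aware attention scale (see attention_processor.py)
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("urae_2k.png")
```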
attention_processor.py ADDED
@@ -0,0 +1,253 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from diffusers.models.attention_processor import Attention
6
+ from typing import Optional
7
+ from diffusers.models.embeddings import apply_rotary_emb
8
+
9
+
10
+ class FluxAttnProcessor2_0:
11
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
12
+
13
+ def __init__(self, train_seq_len=512 + 64 * 64):
14
+ if not hasattr(F, "scaled_dot_product_attention"):
15
+ raise ImportError(
16
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
17
+ )
18
+ self.train_seq_len = train_seq_len
19
+
20
+ def __call__(
21
+ self,
22
+ attn: Attention,
23
+ hidden_states: torch.FloatTensor,
24
+ encoder_hidden_states: torch.FloatTensor = None,
25
+ attention_mask: Optional[torch.FloatTensor] = None,
26
+ image_rotary_emb: Optional[torch.Tensor] = None,
27
+ proportional_attention=False,
28
+ ) -> torch.FloatTensor:
29
+ batch_size, _, _ = (
30
+ hidden_states.shape
31
+ if encoder_hidden_states is None
32
+ else encoder_hidden_states.shape
33
+ )
34
+
35
+ # `sample` projections.
36
+ query = attn.to_q(hidden_states)
37
+ key = attn.to_k(hidden_states)
38
+ value = attn.to_v(hidden_states)
39
+
40
+ inner_dim = key.shape[-1]
41
+ head_dim = inner_dim // attn.heads
42
+
43
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
44
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
45
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
46
+
47
+ if attn.norm_q is not None:
48
+ query = attn.norm_q(query)
49
+ if attn.norm_k is not None:
50
+ key = attn.norm_k(key)
51
+
52
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
53
+ if encoder_hidden_states is not None:
54
+ # `context` projections.
55
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
56
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
57
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
58
+
59
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
60
+ batch_size, -1, attn.heads, head_dim
61
+ ).transpose(1, 2)
62
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
63
+ batch_size, -1, attn.heads, head_dim
64
+ ).transpose(1, 2)
65
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
66
+ batch_size, -1, attn.heads, head_dim
67
+ ).transpose(1, 2)
68
+
69
+ if attn.norm_added_q is not None:
70
+ encoder_hidden_states_query_proj = attn.norm_added_q(
71
+ encoder_hidden_states_query_proj
72
+ )
73
+ if attn.norm_added_k is not None:
74
+ encoder_hidden_states_key_proj = attn.norm_added_k(
75
+ encoder_hidden_states_key_proj
76
+ )
77
+
78
+ # attention
79
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
80
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
81
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
82
+
83
+ if image_rotary_emb is not None:
84
+ query = apply_rotary_emb(query, image_rotary_emb)
85
+ key = apply_rotary_emb(key, image_rotary_emb)
86
+
87
+ if proportional_attention:
88
+ attention_scale = math.sqrt(
89
+ math.log(key.size(2), self.train_seq_len) / head_dim
90
+ )
91
+ else:
92
+ attention_scale = math.sqrt(1 / head_dim)
93
+
94
+ hidden_states = F.scaled_dot_product_attention(
95
+ query, key, value, dropout_p=0.0, is_causal=False, scale=attention_scale
96
+ )
97
+ hidden_states = hidden_states.transpose(1, 2).reshape(
98
+ batch_size, -1, attn.heads * head_dim
99
+ )
100
+ hidden_states = hidden_states.to(query.dtype)
101
+
102
+ if encoder_hidden_states is not None:
103
+ encoder_hidden_states, hidden_states = (
104
+ hidden_states[:, : encoder_hidden_states.shape[1]],
105
+ hidden_states[:, encoder_hidden_states.shape[1] :],
106
+ )
107
+
108
+ # linear proj
109
+ hidden_states = attn.to_out[0](hidden_states)
110
+ # dropout
111
+ hidden_states = attn.to_out[1](hidden_states)
112
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
113
+
114
+ return hidden_states, encoder_hidden_states
115
+ else:
116
+ return hidden_states
117
+
118
+
119
+ class FluxAttnAdaptationProcessor2_0(nn.Module):
120
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
121
+
122
+ def __init__(self, rank=16, dim=3072, to_out=False, train_seq_len=512 + 64 * 64):
123
+ super().__init__()
124
+ if not hasattr(F, "scaled_dot_product_attention"):
125
+ raise ImportError(
126
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
127
+ )
128
+ self.to_q_a = nn.Linear(dim, rank, bias=False)
129
+ self.to_q_b = nn.Linear(rank, dim, bias=False)
130
+ self.to_q_b.weight.data = torch.zeros_like(self.to_q_b.weight.data)
131
+ self.to_k_a = nn.Linear(dim, rank, bias=False)
132
+ self.to_k_b = nn.Linear(rank, dim, bias=False)
133
+ self.to_k_b.weight.data = torch.zeros_like(self.to_k_b.weight.data)
134
+ self.to_v_a = nn.Linear(dim, rank, bias=False)
135
+ self.to_v_b = nn.Linear(rank, dim, bias=False)
136
+ self.to_v_b.weight.data = torch.zeros_like(self.to_v_b.weight.data)
137
+ if to_out:
138
+ self.to_out_a = nn.Linear(dim, rank, bias=False)
139
+ self.to_out_b = nn.Linear(rank, dim, bias=False)
140
+ self.to_out_b.weight.data = torch.zeros_like(self.to_out_b.weight.data)
141
+ self.train_seq_len = train_seq_len
142
+
143
+ def __call__(
144
+ self,
145
+ attn: Attention,
146
+ hidden_states: torch.FloatTensor,
147
+ encoder_hidden_states: torch.FloatTensor = None,
148
+ attention_mask: Optional[torch.FloatTensor] = None,
149
+ image_rotary_emb: Optional[torch.Tensor] = None,
150
+ proportional_attention=False,
151
+ ) -> torch.FloatTensor:
152
+ batch_size, _, _ = (
153
+ hidden_states.shape
154
+ if encoder_hidden_states is None
155
+ else encoder_hidden_states.shape
156
+ )
157
+
158
+ use_adaptation = True
159
+
160
+ # `sample` projections.
161
+ query = attn.to_q(hidden_states)
162
+ key = attn.to_k(hidden_states)
163
+ value = attn.to_v(hidden_states)
164
+
165
+ if use_adaptation:
166
+ query += self.to_q_b(self.to_q_a(hidden_states))
167
+ key += self.to_k_b(self.to_k_a(hidden_states))
168
+ value += self.to_v_b(self.to_v_a(hidden_states))
169
+
170
+ inner_dim = key.shape[-1]
171
+ head_dim = inner_dim // attn.heads
172
+
173
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
174
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
175
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
176
+
177
+ if attn.norm_q is not None:
178
+ query = attn.norm_q(query)
179
+ if attn.norm_k is not None:
180
+ key = attn.norm_k(key)
181
+
182
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
183
+ if encoder_hidden_states is not None:
184
+ # `context` projections.
185
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
186
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
187
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
188
+
189
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
190
+ batch_size, -1, attn.heads, head_dim
191
+ ).transpose(1, 2)
192
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
193
+ batch_size, -1, attn.heads, head_dim
194
+ ).transpose(1, 2)
195
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
196
+ batch_size, -1, attn.heads, head_dim
197
+ ).transpose(1, 2)
198
+
199
+ if attn.norm_added_q is not None:
200
+ encoder_hidden_states_query_proj = attn.norm_added_q(
201
+ encoder_hidden_states_query_proj
202
+ )
203
+ if attn.norm_added_k is not None:
204
+ encoder_hidden_states_key_proj = attn.norm_added_k(
205
+ encoder_hidden_states_key_proj
206
+ )
207
+
208
+ # attention
209
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
210
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
211
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
212
+
213
+ if image_rotary_emb is not None:
214
+ query = apply_rotary_emb(query, image_rotary_emb)
215
+ key = apply_rotary_emb(key, image_rotary_emb)
216
+
217
+ if proportional_attention:
218
+ attention_scale = math.sqrt(
219
+ math.log(key.size(2), self.train_seq_len) / head_dim
220
+ )
221
+ else:
222
+ attention_scale = math.sqrt(1 / head_dim)
223
+
224
+ hidden_states = F.scaled_dot_product_attention(
225
+ query, key, value, dropout_p=0.0, is_causal=False, scale=attention_scale
226
+ )
227
+ hidden_states = hidden_states.transpose(1, 2).reshape(
228
+ batch_size, -1, attn.heads * head_dim
229
+ )
230
+ hidden_states = hidden_states.to(query.dtype)
231
+
232
+ if encoder_hidden_states is not None:
233
+ encoder_hidden_states, hidden_states = (
234
+ hidden_states[:, : encoder_hidden_states.shape[1]],
235
+ hidden_states[:, encoder_hidden_states.shape[1] :],
236
+ )
237
+
238
+ # linear proj
239
+ hidden_states = (
240
+ (
241
+ attn.to_out[0](hidden_states)
242
+ + self.to_out_b(self.to_out_a(hidden_states))
243
+ )
244
+ if use_adaptation
245
+ else attn.to_out[0](hidden_states)
246
+ )
247
+ # dropout
248
+ hidden_states = attn.to_out[1](hidden_states)
249
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
250
+
251
+ return hidden_states, encoder_hidden_states
252
+ else:
253
+ return hidden_states
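
Note on the scaling used above: with proportional attention the usual 1/√d logit scale becomes √(log_{L_train}(L)/d), which equals the standard scale at the training sequence length L_train = 512 + 64·64 and grows slowly for longer sequences, keeping the softmax from flattening as the number of keys increases. A small illustrative sketch follows; the head dimension of 128 is an assumption made only for this example.

```python
# Sketch: behaviour of the proportional attention scale defined above.
import math

head_dim = 128                 # assumed head dimension, for illustration only
train_seq_len = 512 + 64 * 64  # text tokens + 64x64 packed latent tokens (default above)

def attention_scale(seq_len, proportional=True):
    if proportional:
        # log base train_seq_len of seq_len equals 1 at the training length,
        # so this reduces to the standard 1/sqrt(head_dim) scale there.
        return math.sqrt(math.log(seq_len, train_seq_len) / head_dim)
    return math.sqrt(1 / head_dim)

print(attention_scale(train_seq_len))    # == 1/sqrt(head_dim), the default scale
print(attention_scale(512 + 128 * 128))  # ~2K image: slightly larger scale
print(attention_scale(512 + 256 * 256))  # ~4K image: larger still
```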
pipeline_flux.py ADDED
@@ -0,0 +1,789 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
+
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
24
+ from diffusers.models.autoencoders import AutoencoderKL
25
+ from diffusers.models.transformers import FluxTransformer2DModel
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (
28
+ USE_PEFT_BACKEND,
29
+ is_torch_xla_available,
30
+ logging,
31
+ replace_example_docstring,
32
+ scale_lora_layers,
33
+ unscale_lora_layers,
34
+ )
35
+ from diffusers.utils.torch_utils import randn_tensor
36
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
38
+
39
+
40
+ if is_torch_xla_available():
41
+ import torch_xla.core.xla_model as xm
42
+
43
+ XLA_AVAILABLE = True
44
+ else:
45
+ XLA_AVAILABLE = False
46
+
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> import torch
54
+ >>> from diffusers import FluxPipeline
55
+
56
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
57
+ >>> pipe.to("cuda")
58
+ >>> prompt = "A cat holding a sign that says hello world"
59
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
60
+ >>> # Refer to the pipeline documentation for more details.
61
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
62
+ >>> image.save("flux.png")
63
+ ```
64
+ """
65
+
66
+
67
+ def calculate_shift(
68
+ image_seq_len,
69
+ base_seq_len: int = 256,
70
+ max_seq_len: int = 4096,
71
+ base_shift: float = 0.5,
72
+ max_shift: float = 1.16,
73
+ ):
74
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
75
+ b = base_shift - m * base_seq_len
76
+ mu = image_seq_len * m + b
77
+ return mu
78
+
79
+
80
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81
+ def retrieve_timesteps(
82
+ scheduler,
83
+ num_inference_steps: Optional[int] = None,
84
+ device: Optional[Union[str, torch.device]] = None,
85
+ timesteps: Optional[List[int]] = None,
86
+ sigmas: Optional[List[float]] = None,
87
+ **kwargs,
88
+ ):
89
+ r"""
90
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
91
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
92
+
93
+ Args:
94
+ scheduler (`SchedulerMixin`):
95
+ The scheduler to get timesteps from.
96
+ num_inference_steps (`int`):
97
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
98
+ must be `None`.
99
+ device (`str` or `torch.device`, *optional*):
100
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
101
+ timesteps (`List[int]`, *optional*):
102
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
103
+ `num_inference_steps` and `sigmas` must be `None`.
104
+ sigmas (`List[float]`, *optional*):
105
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
106
+ `num_inference_steps` and `timesteps` must be `None`.
107
+
108
+ Returns:
109
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
110
+ second element is the number of inference steps.
111
+ """
112
+ if timesteps is not None and sigmas is not None:
113
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
114
+ if timesteps is not None:
115
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
116
+ if not accepts_timesteps:
117
+ raise ValueError(
118
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
119
+ f" timestep schedules. Please check whether you are using the correct scheduler."
120
+ )
121
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
122
+ timesteps = scheduler.timesteps
123
+ num_inference_steps = len(timesteps)
124
+ elif sigmas is not None:
125
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
+ if not accept_sigmas:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ else:
135
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
136
+ timesteps = scheduler.timesteps
137
+ return timesteps, num_inference_steps
138
+
139
+
140
+ class FluxPipeline(
141
+ DiffusionPipeline,
142
+ FluxLoraLoaderMixin,
143
+ FromSingleFileMixin,
144
+ TextualInversionLoaderMixin,
145
+ ):
146
+ r"""
147
+ The Flux pipeline for text-to-image generation.
148
+
149
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
150
+
151
+ Args:
152
+ transformer ([`FluxTransformer2DModel`]):
153
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
154
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
155
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
156
+ vae ([`AutoencoderKL`]):
157
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
158
+ text_encoder ([`CLIPTextModel`]):
159
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
160
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
161
+ text_encoder_2 ([`T5EncoderModel`]):
162
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
163
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
164
+ tokenizer (`CLIPTokenizer`):
165
+ Tokenizer of class
166
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
167
+ tokenizer_2 (`T5TokenizerFast`):
168
+ Second Tokenizer of class
169
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
170
+ """
171
+
172
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
173
+ _optional_components = []
174
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
175
+
176
+ def __init__(
177
+ self,
178
+ scheduler: FlowMatchEulerDiscreteScheduler,
179
+ vae: AutoencoderKL,
180
+ text_encoder: CLIPTextModel,
181
+ tokenizer: CLIPTokenizer,
182
+ text_encoder_2: T5EncoderModel,
183
+ tokenizer_2: T5TokenizerFast,
184
+ transformer: FluxTransformer2DModel,
185
+ ):
186
+ super().__init__()
187
+
188
+ self.register_modules(
189
+ vae=vae,
190
+ text_encoder=text_encoder,
191
+ text_encoder_2=text_encoder_2,
192
+ tokenizer=tokenizer,
193
+ tokenizer_2=tokenizer_2,
194
+ transformer=transformer,
195
+ scheduler=scheduler,
196
+ )
197
+ self.vae_scale_factor = (
198
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
199
+ )
200
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
201
+ self.tokenizer_max_length = (
202
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
203
+ )
204
+ self.default_sample_size = 64
205
+
206
+ def _get_t5_prompt_embeds(
207
+ self,
208
+ prompt: Union[str, List[str]] = None,
209
+ num_images_per_prompt: int = 1,
210
+ max_sequence_length: int = 512,
211
+ device: Optional[torch.device] = None,
212
+ dtype: Optional[torch.dtype] = None,
213
+ ):
214
+ device = device or self._execution_device
215
+ dtype = dtype or self.text_encoder.dtype
216
+
217
+ prompt = [prompt] if isinstance(prompt, str) else prompt
218
+ batch_size = len(prompt)
219
+
220
+ if isinstance(self, TextualInversionLoaderMixin):
221
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
222
+
223
+ text_inputs = self.tokenizer_2(
224
+ prompt,
225
+ padding="max_length",
226
+ max_length=max_sequence_length,
227
+ truncation=True,
228
+ return_length=False,
229
+ return_overflowing_tokens=False,
230
+ return_tensors="pt",
231
+ )
232
+ text_input_ids = text_inputs.input_ids
233
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
234
+
235
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
236
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
237
+ logger.warning(
238
+ "The following part of your input was truncated because `max_sequence_length` is set to "
239
+ f" {max_sequence_length} tokens: {removed_text}"
240
+ )
241
+
242
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
243
+
244
+ dtype = self.text_encoder_2.dtype
245
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
246
+
247
+ _, seq_len, _ = prompt_embeds.shape
248
+
249
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
250
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
251
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
252
+
253
+ return prompt_embeds
254
+
255
+ def _get_clip_prompt_embeds(
256
+ self,
257
+ prompt: Union[str, List[str]],
258
+ num_images_per_prompt: int = 1,
259
+ device: Optional[torch.device] = None,
260
+ ):
261
+ device = device or self._execution_device
262
+
263
+ prompt = [prompt] if isinstance(prompt, str) else prompt
264
+ batch_size = len(prompt)
265
+
266
+ if isinstance(self, TextualInversionLoaderMixin):
267
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
268
+
269
+ text_inputs = self.tokenizer(
270
+ prompt,
271
+ padding="max_length",
272
+ max_length=self.tokenizer_max_length,
273
+ truncation=True,
274
+ return_overflowing_tokens=False,
275
+ return_length=False,
276
+ return_tensors="pt",
277
+ )
278
+
279
+ text_input_ids = text_inputs.input_ids
280
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
281
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
282
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
283
+ logger.warning(
284
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
285
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
286
+ )
287
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
288
+
289
+ # Use pooled output of CLIPTextModel
290
+ prompt_embeds = prompt_embeds.pooler_output
291
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
292
+
293
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
294
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
295
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
296
+
297
+ return prompt_embeds
298
+
299
+ def encode_prompt(
300
+ self,
301
+ prompt: Union[str, List[str]],
302
+ prompt_2: Union[str, List[str]],
303
+ device: Optional[torch.device] = None,
304
+ num_images_per_prompt: int = 1,
305
+ prompt_embeds: Optional[torch.FloatTensor] = None,
306
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
307
+ max_sequence_length: int = 512,
308
+ lora_scale: Optional[float] = None,
309
+ ):
310
+ r"""
311
+
312
+ Args:
313
+ prompt (`str` or `List[str]`, *optional*):
314
+ prompt to be encoded
315
+ prompt_2 (`str` or `List[str]`, *optional*):
316
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
317
+ used in all text-encoders
318
+ device: (`torch.device`):
319
+ torch device
320
+ num_images_per_prompt (`int`):
321
+ number of images that should be generated per prompt
322
+ prompt_embeds (`torch.FloatTensor`, *optional*):
323
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
324
+ provided, text embeddings will be generated from `prompt` input argument.
325
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
326
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
327
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
328
+ lora_scale (`float`, *optional*):
329
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
330
+ """
331
+ device = device or self._execution_device
332
+
333
+ # set lora scale so that monkey patched LoRA
334
+ # function of text encoder can correctly access it
335
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
336
+ self._lora_scale = lora_scale
337
+
338
+ # dynamically adjust the LoRA scale
339
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
340
+ scale_lora_layers(self.text_encoder, lora_scale)
341
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
342
+ scale_lora_layers(self.text_encoder_2, lora_scale)
343
+
344
+ prompt = [prompt] if isinstance(prompt, str) else prompt
345
+
346
+ if prompt_embeds is None:
347
+ prompt_2 = prompt_2 or prompt
348
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
349
+
350
+ # We only use the pooled prompt output from the CLIPTextModel
351
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
352
+ prompt=prompt,
353
+ device=device,
354
+ num_images_per_prompt=num_images_per_prompt,
355
+ )
356
+ prompt_embeds = self._get_t5_prompt_embeds(
357
+ prompt=prompt_2,
358
+ num_images_per_prompt=num_images_per_prompt,
359
+ max_sequence_length=max_sequence_length,
360
+ device=device,
361
+ )
362
+
363
+ if self.text_encoder is not None:
364
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
365
+ # Retrieve the original scale by scaling back the LoRA layers
366
+ unscale_lora_layers(self.text_encoder, lora_scale)
367
+
368
+ if self.text_encoder_2 is not None:
369
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
370
+ # Retrieve the original scale by scaling back the LoRA layers
371
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
372
+
373
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
374
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
375
+
376
+ return prompt_embeds, pooled_prompt_embeds, text_ids
377
+
378
+ def check_inputs(
379
+ self,
380
+ prompt,
381
+ prompt_2,
382
+ height,
383
+ width,
384
+ prompt_embeds=None,
385
+ pooled_prompt_embeds=None,
386
+ callback_on_step_end_tensor_inputs=None,
387
+ max_sequence_length=None,
388
+ ):
389
+ if height % 8 != 0 or width % 8 != 0:
390
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
391
+
392
+ if callback_on_step_end_tensor_inputs is not None and not all(
393
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
394
+ ):
395
+ raise ValueError(
396
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
397
+ )
398
+
399
+ if prompt is not None and prompt_embeds is not None:
400
+ raise ValueError(
401
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
402
+ " only forward one of the two."
403
+ )
404
+ elif prompt_2 is not None and prompt_embeds is not None:
405
+ raise ValueError(
406
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
407
+ " only forward one of the two."
408
+ )
409
+ elif prompt is None and prompt_embeds is None:
410
+ raise ValueError(
411
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
412
+ )
413
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
414
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
415
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
416
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
417
+
418
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
419
+ raise ValueError(
420
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
421
+ )
422
+
423
+ if max_sequence_length is not None and max_sequence_length > 512:
424
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
425
+
426
+ @staticmethod
427
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
428
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
429
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
430
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
431
+
432
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
433
+
434
+ latent_image_ids = latent_image_ids.reshape(
435
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
436
+ )
437
+
438
+ return latent_image_ids.to(device=device, dtype=dtype)
439
+
440
+ @staticmethod
441
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
442
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
443
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
444
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
445
+
446
+ return latents
447
+
448
+ @staticmethod
449
+ def _unpack_latents(latents, height, width, vae_scale_factor):
450
+ batch_size, num_patches, channels = latents.shape
451
+
452
+ height = height // vae_scale_factor
453
+ width = width // vae_scale_factor
454
+
455
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
456
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
457
+
458
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
459
+
460
+ return latents
461
+
462
+ def enable_vae_slicing(self):
463
+ r"""
464
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
465
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
466
+ """
467
+ self.vae.enable_slicing()
468
+
469
+ def disable_vae_slicing(self):
470
+ r"""
471
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
472
+ computing decoding in one step.
473
+ """
474
+ self.vae.disable_slicing()
475
+
476
+ def enable_vae_tiling(self):
477
+ r"""
478
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
479
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
480
+ processing larger images.
481
+ """
482
+ self.vae.enable_tiling()
483
+
484
+ def disable_vae_tiling(self):
485
+ r"""
486
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
487
+ computing decoding in one step.
488
+ """
489
+ self.vae.disable_tiling()
490
+
491
+ def prepare_latents(
492
+ self,
493
+ batch_size,
494
+ num_channels_latents,
495
+ height,
496
+ width,
497
+ dtype,
498
+ device,
499
+ generator,
500
+ latents=None,
501
+ ):
502
+ height = 2 * (int(height) // self.vae_scale_factor)
503
+ width = 2 * (int(width) // self.vae_scale_factor)
504
+
505
+ shape = (batch_size, num_channels_latents, height, width)
506
+
507
+ if latents is not None:
508
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
509
+ return latents.to(device=device, dtype=dtype), latent_image_ids
510
+
511
+ if isinstance(generator, list) and len(generator) != batch_size:
512
+ raise ValueError(
513
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
514
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
515
+ )
516
+
517
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
518
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
519
+
520
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
521
+
522
+ return latents, latent_image_ids
523
+
524
+ @property
525
+ def guidance_scale(self):
526
+ return self._guidance_scale
527
+
528
+ @property
529
+ def joint_attention_kwargs(self):
530
+ return self._joint_attention_kwargs
531
+
532
+ @property
533
+ def num_timesteps(self):
534
+ return self._num_timesteps
535
+
536
+ @property
537
+ def interrupt(self):
538
+ return self._interrupt
539
+
540
+ @torch.no_grad()
541
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
542
+ def __call__(
543
+ self,
544
+ prompt: Union[str, List[str]] = None,
545
+ prompt_2: Optional[Union[str, List[str]]] = None,
546
+ height: Optional[int] = None,
547
+ width: Optional[int] = None,
548
+ num_inference_steps: int = 28,
549
+ timesteps: List[int] = None,
550
+ guidance_scale: float = 3.5,
551
+ num_images_per_prompt: Optional[int] = 1,
552
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
553
+ latents: Optional[torch.FloatTensor] = None,
554
+ prompt_embeds: Optional[torch.FloatTensor] = None,
555
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
556
+ output_type: Optional[str] = "pil",
557
+ return_dict: bool = True,
558
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
559
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
560
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
561
+ max_sequence_length: int = 512,
562
+ ntk_factor: float = 10.0,
563
+ proportional_attention: bool = True
564
+ ):
565
+ r"""
566
+ Function invoked when calling the pipeline for generation.
567
+
568
+ Args:
569
+ prompt (`str` or `List[str]`, *optional*):
570
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
571
+ instead.
572
+ prompt_2 (`str` or `List[str]`, *optional*):
573
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
574
+ will be used instead.
575
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
576
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
577
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
578
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
579
+ num_inference_steps (`int`, *optional*, defaults to 28):
580
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
581
+ expense of slower inference.
582
+ timesteps (`List[int]`, *optional*):
583
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
584
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
585
+ passed will be used. Must be in descending order.
586
+ guidance_scale (`float`, *optional*, defaults to 3.5):
587
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
588
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
589
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
590
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
591
+ usually at the expense of lower image quality.
592
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
593
+ The number of images to generate per prompt.
594
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
595
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
596
+ to make generation deterministic.
597
+ latents (`torch.FloatTensor`, *optional*):
598
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
599
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
600
+ tensor will be generated by sampling using the supplied random `generator`.
601
+ prompt_embeds (`torch.FloatTensor`, *optional*):
602
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
603
+ provided, text embeddings will be generated from `prompt` input argument.
604
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
605
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
606
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
607
+ output_type (`str`, *optional*, defaults to `"pil"`):
608
+ The output format of the generated image. Choose between
609
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
610
+ return_dict (`bool`, *optional*, defaults to `True`):
611
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
612
+ joint_attention_kwargs (`dict`, *optional*):
613
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
614
+ `self.processor` in
615
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
616
+ callback_on_step_end (`Callable`, *optional*):
617
+ A function that is called at the end of each denoising step during inference. The function is called
618
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
619
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
620
+ `callback_on_step_end_tensor_inputs`.
621
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
622
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
623
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
624
+ `._callback_tensor_inputs` attribute of your pipeline class.
625
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
626
+
627
+ Examples:
628
+
629
+ Returns:
630
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
631
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
632
+ images.
633
+ """
634
+
635
+ height = height or self.default_sample_size * self.vae_scale_factor
636
+ width = width or self.default_sample_size * self.vae_scale_factor
637
+
638
+ # 1. Check inputs. Raise error if not correct
639
+ self.check_inputs(
640
+ prompt,
641
+ prompt_2,
642
+ height,
643
+ width,
644
+ prompt_embeds=prompt_embeds,
645
+ pooled_prompt_embeds=pooled_prompt_embeds,
646
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
647
+ max_sequence_length=max_sequence_length,
648
+ )
649
+
650
+ self._guidance_scale = guidance_scale
651
+ if joint_attention_kwargs is None:
652
+ joint_attention_kwargs = {'proportional_attention': proportional_attention}
653
+ else:
654
+ joint_attention_kwargs = {**joint_attention_kwargs, 'proportional_attention': proportional_attention}
655
+ self._joint_attention_kwargs = joint_attention_kwargs
656
+ self._interrupt = False
657
+
658
+ # 2. Define call parameters
659
+ if prompt is not None and isinstance(prompt, str):
660
+ batch_size = 1
661
+ elif prompt is not None and isinstance(prompt, list):
662
+ batch_size = len(prompt)
663
+ else:
664
+ batch_size = prompt_embeds.shape[0]
665
+
666
+ device = self._execution_device
667
+
668
+ lora_scale = (
669
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
670
+ )
671
+ (
672
+ prompt_embeds,
673
+ pooled_prompt_embeds,
674
+ text_ids,
675
+ ) = self.encode_prompt(
676
+ prompt=prompt,
677
+ prompt_2=prompt_2,
678
+ prompt_embeds=prompt_embeds,
679
+ pooled_prompt_embeds=pooled_prompt_embeds,
680
+ device=device,
681
+ num_images_per_prompt=num_images_per_prompt,
682
+ max_sequence_length=max_sequence_length,
683
+ lora_scale=lora_scale,
684
+ )
685
+
686
+ # 4. Prepare latent variables
687
+ num_channels_latents = self.transformer.config.in_channels // 4
688
+ latents, latent_image_ids = self.prepare_latents(
689
+ batch_size * num_images_per_prompt,
690
+ num_channels_latents,
691
+ height,
692
+ width,
693
+ prompt_embeds.dtype,
694
+ device,
695
+ generator,
696
+ latents,
697
+ )
698
+
699
+ # 5. Prepare timesteps
700
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
701
+ image_seq_len = latents.shape[1]
702
+ mu = calculate_shift(
703
+ image_seq_len,
704
+ self.scheduler.config.base_image_seq_len,
705
+ self.scheduler.config.max_image_seq_len,
706
+ self.scheduler.config.base_shift,
707
+ self.scheduler.config.max_shift,
708
+ )
709
+ timesteps, num_inference_steps = retrieve_timesteps(
710
+ self.scheduler,
711
+ num_inference_steps,
712
+ device,
713
+ timesteps,
714
+ sigmas,
715
+ mu=mu,
716
+ )
717
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
718
+ self._num_timesteps = len(timesteps)
719
+
720
+ # handle guidance
721
+ if self.transformer.config.guidance_embeds:
722
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
723
+ guidance = guidance.expand(latents.shape[0])
724
+ else:
725
+ guidance = None
726
+
727
+ # 6. Denoising loop
728
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
729
+ for i, t in enumerate(timesteps):
730
+ if self.interrupt:
731
+ continue
732
+
733
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
734
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
735
+
736
+ noise_pred = self.transformer(
737
+ hidden_states=latents,
738
+ timestep=timestep / 1000,
739
+ guidance=guidance,
740
+ pooled_projections=pooled_prompt_embeds,
741
+ encoder_hidden_states=prompt_embeds,
742
+ txt_ids=text_ids,
743
+ img_ids=latent_image_ids,
744
+ joint_attention_kwargs=self.joint_attention_kwargs,
745
+ return_dict=False,
746
+ ntk_factor=ntk_factor
747
+ )[0]
748
+
749
+ # compute the previous noisy sample x_t -> x_t-1
750
+ latents_dtype = latents.dtype
751
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
752
+
753
+ if latents.dtype != latents_dtype:
754
+ #if torch.backends.mps.is_available():
755
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
756
+ latents = latents.to(latents_dtype)
757
+
758
+ if callback_on_step_end is not None:
759
+ callback_kwargs = {}
760
+ for k in callback_on_step_end_tensor_inputs:
761
+ callback_kwargs[k] = locals()[k]
762
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
763
+
764
+ latents = callback_outputs.pop("latents", latents)
765
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
766
+
767
+ # call the callback, if provided
768
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
769
+ progress_bar.update()
770
+
771
+ if XLA_AVAILABLE:
772
+ xm.mark_step()
773
+
774
+ if output_type == "latent":
775
+ image = latents
776
+
777
+ else:
778
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
779
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
780
+ image = self.vae.decode(latents, return_dict=False)[0]
781
+ image = self.image_processor.postprocess(image, output_type=output_type)
782
+
783
+ # Offload all models
784
+ self.maybe_free_model_hooks()
785
+
786
+ if not return_dict:
787
+ return (image,)
788
+
789
+ return FluxPipelineOutput(images=image)
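
A quick sanity check on the latent layout used throughout the pipeline above: `_pack_latents` folds each 2×2 latent patch into one token and `_unpack_latents` is its exact inverse. The sketch below assumes a 1024×1024 image and the `vae_scale_factor` of 16 computed in `__init__`.

```python
# Sketch: the 2x2 latent patchification above round-trips exactly.
import torch

from pipeline_flux import FluxPipeline

batch, channels, lat_h, lat_w = 1, 16, 128, 128   # latent grid for a 1024x1024 image
latents = torch.randn(batch, channels, lat_h, lat_w)

packed = FluxPipeline._pack_latents(latents, batch, channels, lat_h, lat_w)
assert packed.shape == (batch, (lat_h // 2) * (lat_w // 2), channels * 4)  # (1, 4096, 64)

# _unpack_latents takes the *pixel* height/width together with the VAE scale factor.
restored = FluxPipeline._unpack_latents(packed, 1024, 1024, vae_scale_factor=16)
assert torch.equal(restored, latents)
```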
transformer_flux.py ADDED
@@ -0,0 +1,560 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union, List
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
10
+ from diffusers.models.attention import FeedForward
11
+ from diffusers.models.attention_processor import (
12
+ Attention,
13
+ AttentionProcessor
14
+ )
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
17
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
18
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
19
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed
20
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
21
+ from attention_processor import FluxAttnProcessor2_0
22
+
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ @maybe_allow_in_graph
28
+ class FluxSingleTransformerBlock(nn.Module):
29
+ r"""
30
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
31
+
32
+ Reference: https://arxiv.org/abs/2403.03206
33
+
34
+ Parameters:
35
+ dim (`int`): The number of channels in the input and output.
36
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
37
+ attention_head_dim (`int`): The number of channels in each head.
38
+ mlp_ratio (`float`, defaults to 4.0): Ratio of the MLP hidden dimension to `dim`.
40
+ """
41
+
42
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
43
+ super().__init__()
44
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
45
+
46
+ self.norm = AdaLayerNormZeroSingle(dim)
47
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
48
+ self.act_mlp = nn.GELU(approximate="tanh")
49
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
50
+
51
+ processor = FluxAttnProcessor2_0()
52
+ self.attn = Attention(
53
+ query_dim=dim,
54
+ cross_attention_dim=None,
55
+ dim_head=attention_head_dim,
56
+ heads=num_attention_heads,
57
+ out_dim=dim,
58
+ bias=True,
59
+ processor=processor,
60
+ qk_norm="rms_norm",
61
+ eps=1e-6,
62
+ pre_only=True,
63
+ )
64
+
65
+ def forward(
66
+ self,
67
+ hidden_states: torch.FloatTensor,
68
+ temb: torch.FloatTensor,
69
+ image_rotary_emb=None,
70
+ joint_attention_kwargs=None
71
+ ):
72
+ residual = hidden_states
73
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
74
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
75
+ joint_attention_kwargs = joint_attention_kwargs or {}
76
+ attn_output = self.attn(
77
+ hidden_states=norm_hidden_states,
78
+ image_rotary_emb=image_rotary_emb,
79
+ **joint_attention_kwargs,
80
+ )
81
+
82
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
83
+ gate = gate.unsqueeze(1)
84
+ hidden_states = gate * self.proj_out(hidden_states)
85
+ hidden_states = residual + hidden_states
86
+ if hidden_states.dtype == torch.float16:
87
+ hidden_states = hidden_states.clip(-65504, 65504)
88
+
89
+ return hidden_states
90
+
91
+
92
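A shape-level sketch of the parallel attention/MLP design above, using the default Flux dimensions (illustrative only; it assumes a recent diffusers version and the FluxAttnProcessor2_0 from attention_processor.py, which tolerates image_rotary_emb=None):

import torch

block = FluxSingleTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)
x = torch.randn(1, 1024, 3072)   # packed text+image tokens
temb = torch.randn(1, 3072)      # combined timestep/text conditioning
out = block(x, temb)             # same shape as x
# Internally: the attention output (1, 1024, 3072) and the MLP output (1, 1024, 12288)
# are concatenated to (1, 1024, 15360), projected back to 3072, gated, and added to the residual.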
+ @maybe_allow_in_graph
93
+ class FluxTransformerBlock(nn.Module):
94
+ r"""
95
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
96
+
97
+ Reference: https://arxiv.org/abs/2403.03206
98
+
99
+ Parameters:
100
+ dim (`int`): The number of channels in the input and output.
101
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
102
+ attention_head_dim (`int`): The number of channels in each head.
103
+ qk_norm (`str`, defaults to `"rms_norm"`): Normalization applied to the query and key projections.
+ eps (`float`, defaults to 1e-6): Epsilon used by the normalization layers.
105
+ """
106
+
107
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
108
+ super().__init__()
109
+
110
+ self.norm1 = AdaLayerNormZero(dim)
111
+
112
+ self.norm1_context = AdaLayerNormZero(dim)
113
+
114
+ if hasattr(F, "scaled_dot_product_attention"):
115
+ processor = FluxAttnProcessor2_0()
116
+ else:
117
+ raise ValueError(
118
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
119
+ )
120
+ self.attn = Attention(
121
+ query_dim=dim,
122
+ cross_attention_dim=None,
123
+ added_kv_proj_dim=dim,
124
+ dim_head=attention_head_dim,
125
+ heads=num_attention_heads,
126
+ out_dim=dim,
127
+ context_pre_only=False,
128
+ bias=True,
129
+ processor=processor,
130
+ qk_norm=qk_norm,
131
+ eps=eps,
132
+ )
133
+
134
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
135
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
136
+
137
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
138
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
139
+
140
+ # let chunk size default to None
141
+ self._chunk_size = None
142
+ self._chunk_dim = 0
143
+
144
+ def forward(
145
+ self,
146
+ hidden_states: torch.FloatTensor,
147
+ encoder_hidden_states: torch.FloatTensor,
148
+ temb: torch.FloatTensor,
149
+ image_rotary_emb=None,
150
+ joint_attention_kwargs=None,
151
+ ):
152
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
153
+
154
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
155
+ encoder_hidden_states, emb=temb
156
+ )
157
+ joint_attention_kwargs = joint_attention_kwargs or {}
158
+ # Attention.
159
+ attn_output, context_attn_output = self.attn(
160
+ hidden_states=norm_hidden_states,
161
+ encoder_hidden_states=norm_encoder_hidden_states,
162
+ image_rotary_emb=image_rotary_emb,
163
+ **joint_attention_kwargs,
164
+ )
165
+
166
+ # Process attention outputs for the `hidden_states`.
167
+ attn_output = gate_msa.unsqueeze(1) * attn_output
168
+ hidden_states = hidden_states + attn_output
169
+
170
+ norm_hidden_states = self.norm2(hidden_states)
171
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
172
+
173
+ ff_output = self.ff(norm_hidden_states)
174
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
175
+
176
+ hidden_states = hidden_states + ff_output
177
+
178
+ # Process attention outputs for the `encoder_hidden_states`.
179
+
180
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
181
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
182
+
183
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
184
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
185
+
186
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
187
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
188
+ if encoder_hidden_states.dtype == torch.float16:
189
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
190
+
191
+ return encoder_hidden_states, hidden_states
192
+
193
+
194
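The dual-stream block keeps separate image and text streams that share a single joint attention; a minimal sketch of its calling convention (illustrative, same assumptions as the sketch above):

import torch

block = FluxTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)
img = torch.randn(1, 1024, 3072)   # image-stream tokens
txt = torch.randn(1, 512, 3072)    # text-stream (context) tokens
temb = torch.randn(1, 3072)
txt_out, img_out = block(hidden_states=img, encoder_hidden_states=txt, temb=temb)
# Both streams keep their shapes; AdaLayerNormZero supplies per-stream shift/scale/gate
# vectors from temb around both the attention and feed-forward sublayers.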
+ class FluxPosEmbed(nn.Module):
195
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
196
+ def __init__(self, theta: int, axes_dim: List[int]):
197
+ super().__init__()
198
+ self.theta = theta
199
+ self.axes_dim = axes_dim
200
+
201
+ def forward(self, ids: torch.Tensor, ntk_factor: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
202
+ n_axes = ids.shape[-1]
203
+ cos_out = []
204
+ sin_out = []
205
+ pos = ids.float()
206
+ is_mps = ids.device.type == "mps"
207
+ freqs_dtype = torch.float32 if is_mps else torch.float64
208
+ for i in range(n_axes):
209
+ cos, sin = get_1d_rotary_pos_embed(
210
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype,
211
+ ntk_factor=ntk_factor
212
+ )
213
+ cos_out.append(cos)
214
+ sin_out.append(sin)
215
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
216
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
217
+ return freqs_cos, freqs_sin
218
+
219
+
220
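FluxPosEmbed computes one rotary table per positional axis and concatenates them channel-wise, so with axes_dim=(16, 56, 56) the three axes together cover the full attention_head_dim of 128. A sketch of the id layout it expects (illustrative; the real ids are prepared by the pipeline):

import torch

h, w = 32, 32                                  # packed latent grid
img_ids = torch.zeros(h, w, 3)
img_ids[..., 1] = torch.arange(h)[:, None]     # row position on axis 1
img_ids[..., 2] = torch.arange(w)[None, :]     # column position on axis 2
img_ids = img_ids.reshape(-1, 3)
txt_ids = torch.zeros(512, 3)                  # text tokens sit at position 0

pos = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
cos, sin = pos(torch.cat([txt_ids, img_ids], dim=0), ntk_factor=1)
# cos.shape == sin.shape == (512 + h * w, 128). ntk_factor > 1 enlarges the rotary base
# (theta * ntk_factor), the NTK-aware scaling this file threads through, typically used
# when sampling at resolutions larger than the training grid.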
+ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
221
+ """
222
+ The Transformer model introduced in Flux.
223
+
224
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
225
+
226
+ Parameters:
227
+ patch_size (`int`): Patch size to turn the input data into small patches.
228
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT (dual-stream) blocks to use.
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
233
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
234
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
235
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
236
+ """
237
+
238
+ _supports_gradient_checkpointing = True
239
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
240
+
241
+ @register_to_config
242
+ def __init__(
243
+ self,
244
+ patch_size: int = 1,
245
+ in_channels: int = 64,
246
+ num_layers: int = 19,
247
+ num_single_layers: int = 38,
248
+ attention_head_dim: int = 128,
249
+ num_attention_heads: int = 24,
250
+ joint_attention_dim: int = 4096,
251
+ pooled_projection_dim: int = 768,
252
+ guidance_embeds: bool = False,
253
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
254
+ ):
255
+ super().__init__()
256
+ self.out_channels = in_channels
257
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
258
+
259
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
260
+
261
+ text_time_guidance_cls = (
262
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
263
+ )
264
+ self.time_text_embed = text_time_guidance_cls(
265
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
266
+ )
267
+
268
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
269
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
270
+
271
+ self.transformer_blocks = nn.ModuleList(
272
+ [
273
+ FluxTransformerBlock(
274
+ dim=self.inner_dim,
275
+ num_attention_heads=self.config.num_attention_heads,
276
+ attention_head_dim=self.config.attention_head_dim,
277
+ )
278
+ for i in range(self.config.num_layers)
279
+ ]
280
+ )
281
+
282
+ self.single_transformer_blocks = nn.ModuleList(
283
+ [
284
+ FluxSingleTransformerBlock(
285
+ dim=self.inner_dim,
286
+ num_attention_heads=self.config.num_attention_heads,
287
+ attention_head_dim=self.config.attention_head_dim,
288
+ )
289
+ for i in range(self.config.num_single_layers)
290
+ ]
291
+ )
292
+
293
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
294
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
295
+
296
+ self.gradient_checkpointing = False
297
+
298
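With the defaults registered above the layout matches FLUX.1 (inner_dim = 24 * 128 = 3072, 19 dual-stream plus 38 single-stream blocks). A reduced instantiation for experimentation (the small sizes below are illustrative choices, not values from this commit):

tiny = FluxTransformer2DModel(
    patch_size=1,
    in_channels=16,
    num_layers=2,
    num_single_layers=4,
    attention_head_dim=32,
    num_attention_heads=4,
    joint_attention_dim=256,
    pooled_projection_dim=64,
    guidance_embeds=False,
    axes_dims_rope=(4, 14, 14),   # must sum to attention_head_dim
)
# tiny.inner_dim == 4 * 32 == 128, and proj_out maps back to patch_size**2 * in_channels == 16.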
+ @property
299
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
300
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
301
+ r"""
302
+ Returns:
303
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
304
+ indexed by its weight name.
305
+ """
306
+ # set recursively
307
+ processors = {}
308
+
309
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
310
+ if hasattr(module, "get_processor"):
311
+ processors[f"{name}.processor"] = module.get_processor()
312
+
313
+ for sub_name, child in module.named_children():
314
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
315
+
316
+ return processors
317
+
318
+ for name, module in self.named_children():
319
+ fn_recursive_add_processors(name, module, processors)
320
+
321
+ return processors
322
+
323
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
324
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
325
+ r"""
326
+ Sets the attention processor to use to compute attention.
327
+
328
+ Parameters:
329
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
330
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
331
+ for **all** `Attention` layers.
332
+
333
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
334
+ processor. This is strongly recommended when setting trainable attention processors.
335
+
336
+ """
337
+ count = len(self.attn_processors.keys())
338
+
339
+ if isinstance(processor, dict) and len(processor) != count:
340
+ raise ValueError(
341
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
342
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
343
+ )
344
+
345
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
346
+ if hasattr(module, "set_processor"):
347
+ if not isinstance(processor, dict):
348
+ module.set_processor(processor)
349
+ else:
350
+ module.set_processor(processor.pop(f"{name}.processor"))
351
+
352
+ for sub_name, child in module.named_children():
353
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
354
+
355
+ for name, module in self.named_children():
356
+ fn_recursive_attn_processor(name, module, processor)
357
+
358
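Together, attn_processors and set_attn_processor let callers inspect or replace every attention processor at once, which is how the custom FluxAttnProcessor2_0 from attention_processor.py can be (re)attached. A brief sketch (illustrative; `model` is assumed to be a FluxTransformer2DModel instance):

from attention_processor import FluxAttnProcessor2_0

# Same processor class for every Attention layer:
model.set_attn_processor(FluxAttnProcessor2_0())

# Or a per-layer mapping, keyed exactly like the attn_processors property:
model.set_attn_processor({name: FluxAttnProcessor2_0() for name in model.attn_processors})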
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
359
+ def unfuse_qkv_projections(self):
360
+ """Disables the fused QKV projection if enabled.
361
+
362
+ <Tip warning={true}>
363
+
364
+ This API is 🧪 experimental.
365
+
366
+ </Tip>
367
+
368
+ """
369
+ if self.original_attn_processors is not None:
370
+ self.set_attn_processor(self.original_attn_processors)
371
+
372
+ def _set_gradient_checkpointing(self, module, value=False):
373
+ if hasattr(module, "gradient_checkpointing"):
374
+ module.gradient_checkpointing = value
375
+
376
+ def forward(
377
+ self,
378
+ hidden_states: torch.Tensor,
379
+ encoder_hidden_states: torch.Tensor = None,
380
+ pooled_projections: torch.Tensor = None,
381
+ timestep: torch.LongTensor = None,
382
+ img_ids: torch.Tensor = None,
383
+ txt_ids: torch.Tensor = None,
384
+ guidance: torch.Tensor = None,
385
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
386
+ controlnet_block_samples=None,
387
+ controlnet_single_block_samples=None,
388
+ return_dict: bool = True,
389
+ ntk_factor: float = 1,
390
+ controlnet_blocks_repeat: bool = False,
391
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
392
+ """
393
+ The [`FluxTransformer2DModel`] forward method.
394
+
395
+ Args:
396
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
397
+ Input `hidden_states`.
398
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
399
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
400
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
401
+ from the embeddings of input conditions.
402
+ timestep ( `torch.LongTensor`):
403
+ Used to indicate denoising step.
404
+ controlnet_block_samples / controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
+ Lists of tensors that, if specified, are added to the residuals of the dual-stream and
+ single-stream transformer blocks respectively.
+ ntk_factor (`float`, *optional*, defaults to 1):
+ NTK-aware scaling factor applied to the rotary position embeddings (the rotary base is
+ multiplied by this factor), typically used when sampling beyond the training resolution.
406
+ joint_attention_kwargs (`dict`, *optional*):
407
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
408
+ `self.processor` in
409
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
410
+ return_dict (`bool`, *optional*, defaults to `True`):
411
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
412
+ tuple.
413
+
414
+ Returns:
415
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
416
+ `tuple` where the first element is the sample tensor.
417
+ """
418
+
419
+ if txt_ids.ndim == 3:
420
+ logger.warning(
421
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
422
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
423
+ )
424
+ txt_ids = txt_ids[0]
425
+ if img_ids.ndim == 3:
426
+ logger.warning(
427
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
428
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
429
+ )
430
+ img_ids = img_ids[0]
431
+
432
+ if joint_attention_kwargs is not None:
433
+ joint_attention_kwargs = joint_attention_kwargs.copy()
434
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
435
+ else:
436
+ lora_scale = 1.0
437
+
438
+ if USE_PEFT_BACKEND:
439
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
440
+ scale_lora_layers(self, lora_scale)
441
+ else:
442
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
443
+ logger.warning(
444
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
445
+ )
446
+ hidden_states = self.x_embedder(hidden_states)
447
+
448
+ timestep = timestep.to(hidden_states.dtype) * 1000
449
+ if guidance is not None:
450
+ guidance = guidance.to(hidden_states.dtype) * 1000
451
+ else:
452
+ guidance = None
453
+ temb = (
454
+ self.time_text_embed(timestep, pooled_projections)
455
+ if guidance is None
456
+ else self.time_text_embed(timestep, guidance, pooled_projections)
457
+ )
458
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
459
+
460
+ ids = torch.cat((txt_ids, img_ids), dim=0)
461
+ image_rotary_emb = self.pos_embed(ids, ntk_factor=ntk_factor)
462
+
463
+ for index_block, block in enumerate(self.transformer_blocks):
464
+ if self.training and self.gradient_checkpointing:
465
+
466
+ def create_custom_forward(module, return_dict=None):
467
+ def custom_forward(*inputs):
468
+ if return_dict is not None:
469
+ return module(*inputs, return_dict=return_dict)
470
+ else:
471
+ return module(*inputs)
472
+
473
+ return custom_forward
474
+
475
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
476
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
477
+ create_custom_forward(block),
478
+ hidden_states,
479
+ encoder_hidden_states,
480
+ temb,
481
+ image_rotary_emb,
482
+ joint_attention_kwargs,
483
+ **ckpt_kwargs,
484
+ )
485
+
486
+ else:
487
+ encoder_hidden_states, hidden_states = block(
488
+ hidden_states=hidden_states,
489
+ encoder_hidden_states=encoder_hidden_states,
490
+ temb=temb,
491
+ image_rotary_emb=image_rotary_emb,
492
+ joint_attention_kwargs=joint_attention_kwargs,
493
+ )
494
+
495
+ # controlnet residual
496
+ if controlnet_block_samples is not None:
497
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
498
+ interval_control = int(np.ceil(interval_control))
499
+ # For Xlabs ControlNet.
500
+ if controlnet_blocks_repeat:
501
+ hidden_states = (
502
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
503
+ )
504
+ else:
505
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
506
+
507
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
508
+
509
+ for index_block, block in enumerate(self.single_transformer_blocks):
510
+ if self.training and self.gradient_checkpointing:
511
+
512
+ def create_custom_forward(module, return_dict=None):
513
+ def custom_forward(*inputs):
514
+ if return_dict is not None:
515
+ return module(*inputs, return_dict=return_dict)
516
+ else:
517
+ return module(*inputs)
518
+
519
+ return custom_forward
520
+
521
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
522
+ hidden_states = torch.utils.checkpoint.checkpoint(
523
+ create_custom_forward(block),
524
+ hidden_states,
525
+ temb,
526
+ image_rotary_emb,
527
+ joint_attention_kwargs,
528
+ **ckpt_kwargs,
529
+ )
530
+
531
+ else:
532
+ hidden_states = block(
533
+ hidden_states=hidden_states,
534
+ temb=temb,
535
+ image_rotary_emb=image_rotary_emb,
536
+ joint_attention_kwargs=joint_attention_kwargs,
537
+ )
538
+
539
+ # controlnet residual
540
+ if controlnet_single_block_samples is not None:
541
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
542
+ interval_control = int(np.ceil(interval_control))
543
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
544
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
545
+ + controlnet_single_block_samples[index_block // interval_control]
546
+ )
547
+
548
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
549
+
550
+ hidden_states = self.norm_out(hidden_states, temb)
551
+ output = self.proj_out(hidden_states)
552
+
553
+ if USE_PEFT_BACKEND:
554
+ # remove `lora_scale` from each PEFT layer
555
+ unscale_lora_layers(self, lora_scale)
556
+
557
+ if not return_dict:
558
+ return (output,)
559
+
560
+ return Transformer2DModelOutput(sample=output)
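End to end, the forward pass above consumes packed latents plus text conditioning and returns a prediction with the same packed shape. A shape sketch with the default configuration (illustrative values; `model` is assumed to be a default FluxTransformer2DModel, and in practice pipeline_flux.py prepares these inputs):

import torch

bsz, txt_len, img_len = 1, 512, 1024        # e.g. a 512x512 image -> 32x32 packed tokens
sample = model(
    hidden_states=torch.randn(bsz, img_len, 64),            # packed VAE latents (2x2 patches x 16 channels)
    encoder_hidden_states=torch.randn(bsz, txt_len, 4096),  # T5 token embeddings
    pooled_projections=torch.randn(bsz, 768),                # pooled CLIP embedding
    timestep=torch.tensor([1.0]),                            # sigma in [0, 1]; scaled by 1000 internally
    img_ids=torch.zeros(img_len, 3),
    txt_ids=torch.zeros(txt_len, 3),
    ntk_factor=1.0,
    return_dict=False,
)[0]
# sample.shape == (bsz, img_len, 64), matching the packed latent layout.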