ivand-all3d committed
Commit e6ca313 · 1 Parent(s): da1c6e0

Add flux-image-variations

flux-image-variations/config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "_class_name": "FluxTransformer2DModel",
+   "_diffusers_version": "0.31.0",
+   "_name_or_path": "output-model",
+   "attention_head_dim": 128,
+   "axes_dims_rope": [
+     16,
+     56,
+     56
+   ],
+   "guidance_embeds": false,
+   "in_channels": 64,
+   "joint_attention_dim": 4096,
+   "num_attention_heads": 24,
+   "num_layers": 19,
+   "num_single_layers": 38,
+   "patch_size": 1,
+   "pooled_projection_dim": 768
+ }
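
For orientation, this config describes a FluxTransformer2DModel with 19 double-stream and 38 single-stream blocks, an inner width of num_attention_heads x attention_head_dim = 24 x 128 = 3072, and guidance_embeds set to false, which is why the pipeline added below runs classic classifier-free guidance with a separate unconditional pass. A minimal inspection sketch, assuming the repository has been downloaded so that this folder is available locally as flux-image-variations/:

from diffusers import FluxTransformer2DModel

# Read config.json without touching the ~47.6 GB weight file
# (the local path "flux-image-variations" is an assumption).
config = FluxTransformer2DModel.load_config("flux-image-variations")
print(config["num_layers"], config["num_single_layers"])             # 19 38
print(config["num_attention_heads"] * config["attention_head_dim"])  # 3072
print(config["guidance_embeds"])                                     # False
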
flux-image-variations/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f6bfb9d8e1fdfa88ef5b6290d4bbb4c4b6c5db1c1d16d35f33153230b6b0554e
+ size 47564852120
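
The checkpoint itself is stored with Git LFS, so the diff records only the pointer: the object's SHA-256 and its size (47,564,852,120 bytes, roughly 47.6 GB). After pulling the file, its integrity can be checked against that hash; a short sketch, with the local path being an assumption:

import hashlib

EXPECTED = "f6bfb9d8e1fdfa88ef5b6290d4bbb4c4b6c5db1c1d16d35f33153230b6b0554e"
path = "flux-image-variations/diffusion_pytorch_model.safetensors"  # assumed local path

# Stream in 16 MiB chunks so the ~47.6 GB file never needs to fit in memory.
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 24), b""):
        h.update(chunk)
print(h.hexdigest() == EXPECTED)
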
inference_flux_model.py → inference_flux.py RENAMED
@@ -71,7 +71,7 @@ def main():
     global pipe
     clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
     pipe = FluxWithCFGPipeline.from_pretrained("ostris/OpenFLUX.1", text_encoder=clip, transformer=None, torch_dtype=torch.bfloat16)
-    pipe.transformer = FluxTransformer2DModel.from_pretrained("flux-image-variations-model", torch_dtype=torch.bfloat16)
+    pipe.transformer = FluxTransformer2DModel.from_pretrained("flux-image-variations", torch_dtype=torch.bfloat16)
     pipe.to("cuda")

     img_prompt = Image.open(args.image_prompt) if args.image_prompt else None
@@ -79,4 +79,4 @@
                      num_images=args.num_images, resolution=args.resolution)

 if __name__ == "__main__":
-    main()
+    main()
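
Apart from the rename, the only functional change is loading the transformer from the new flux-image-variations folder instead of flux-image-variations-model. Reading the hunk's context lines together, the script now roughly amounts to the sketch below; the import of FluxWithCFGPipeline from open_flux_pipeline.py is taken from the file added next, while the reference image path, prompts, and guidance value are placeholders rather than settings from this repository:

import torch
from PIL import Image
from transformers import CLIPModel
from diffusers import FluxTransformer2DModel

from open_flux_pipeline import FluxWithCFGPipeline  # class added in this commit

# Build OpenFLUX.1 without its transformer, then swap in the image-variation
# transformer shipped in this commit.
clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
pipe = FluxWithCFGPipeline.from_pretrained(
    "ostris/OpenFLUX.1", text_encoder=clip, transformer=None, torch_dtype=torch.bfloat16
)
pipe.transformer = FluxTransformer2DModel.from_pretrained(
    "flux-image-variations", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# One guided variation of a reference image ("reference.png" is hypothetical).
out = pipe(
    image_prompt=Image.open("reference.png"),
    prompt="",
    negative_prompt="",
    guidance_scale=3.5,
    num_inference_steps=28,
    height=1024,
    width=1024,
)
out.images[0].save("variation.png")
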
open_flux_pipeline.py ADDED
@@ -0,0 +1,242 @@
+ import numpy as np
+ import torch
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+ from diffusers import FluxPipeline
+ from typing import List, Union, Optional, Dict, Any, Callable
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+
+ from diffusers.utils import is_torch_xla_available
+
+ from modified_flux import FluxImageConditionedPipeline
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+ # TODO this is rough. Need to properly stack unconditional or make it optional
+ class FluxWithCFGPipeline(FluxImageConditionedPipeline):
+     def __call__(
+         self,
+         image_prompt = None,
+         prompt: Union[str, List[str]] = None,
+         prompt_2: Optional[Union[str, List[str]]] = None,
+         negative_image_prompt = None,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         negative_prompt_2: Optional[Union[str, List[str]]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 28,
+         timesteps: List[int] = None,
+         guidance_scale: float = 7.0,
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         max_sequence_length: int = 512,
+     ):
+
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             prompt_2,
+             height,
+             width,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+             max_sequence_length=max_sequence_length,
+         )
+
+         self._guidance_scale = guidance_scale
+         self._joint_attention_kwargs = joint_attention_kwargs
+         self._interrupt = False
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+
+         lora_scale = (
+             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+         )
+         (
+             prompt_embeds,
+             pooled_prompt_embeds,
+             text_ids,
+         ) = self.encode_prompt(
+             image_prompt=image_prompt,
+             prompt=prompt,
+             prompt_2=prompt_2,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             max_sequence_length=max_sequence_length,
+             lora_scale=lora_scale,
+         )
+         (
+             negative_prompt_embeds,
+             negative_pooled_prompt_embeds,
+             negative_text_ids,
+         ) = self.encode_prompt(
+             image_prompt=negative_image_prompt,
+             prompt=negative_prompt,
+             prompt_2=negative_prompt_2,
+             prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             max_sequence_length=max_sequence_length,
+             lora_scale=lora_scale,
+         )
+
+         # 4. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels // 4
+         latents, latent_image_ids = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 4.5. Concat CLIP image embeddings
+         # if image_prompt is not None:
+         #     image_prompt_embeds = self._encode_image_with_clip(image_prompt, device=device)
+         #     latents_conditioned = torch.cat((latents, image_prompt_embeds.reshape((batch_size, 12, 64))), dim=1)
+         #     latent_image_ids_conditioned = torch.cat((latent_image_ids, torch.zeros(12, 3).to(device)), dim=0)
+         # else:
+         #     latents_conditioned = latents
+         #     latent_image_ids_conditioned = latent_image_ids
+
+         # 5. Prepare timesteps
+         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+         image_seq_len = latents.shape[1]
+         mu = calculate_shift(
+             image_seq_len,
+             self.scheduler.config.base_image_seq_len,
+             self.scheduler.config.max_image_seq_len,
+             self.scheduler.config.base_shift,
+             self.scheduler.config.max_shift,
+         )
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler,
+             num_inference_steps,
+             device,
+             timesteps,
+             sigmas,
+             mu=mu,
+         )
+         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+         self._num_timesteps = len(timesteps)
+
+         # 6. Denoising loop
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+                 print(f" - Memory: {torch.cuda.memory_allocated()/(1024**3):.3f} GB")
+
+                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                 # handle guidance
+                 if self.transformer.config.guidance_embeds:
+                     guidance = torch.tensor([guidance_scale], device=device)
+                     guidance = guidance.expand(latents.shape[0])
+                 else:
+                     guidance = None
+
+                 noise_pred_text = self.transformer(
+                     hidden_states=latents,#_conditioned,
+                     timestep=timestep / 1000,
+                     guidance=guidance,
+                     pooled_projections=pooled_prompt_embeds,
+                     encoder_hidden_states=prompt_embeds,
+                     txt_ids=text_ids,
+                     img_ids=latent_image_ids,#_conditioned,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
+                     return_dict=False,
+                 )[0]
+                 # if image_prompt is not None:
+                 #     noise_pred_text = noise_pred_text.narrow(1, 0, 1024)
+
+                 # todo combine these
+                 noise_pred_uncond = self.transformer(
+                     hidden_states=latents,
+                     timestep=timestep / 1000,
+                     guidance=guidance,
+                     pooled_projections=negative_pooled_prompt_embeds,
+                     encoder_hidden_states=negative_prompt_embeds,
+                     txt_ids=negative_text_ids,
+                     img_ids=latent_image_ids,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents_dtype = latents.dtype
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                 if latents.dtype != latents_dtype:
+                     if torch.backends.mps.is_available():
+                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                         latents = latents.to(latents_dtype)
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                 print(f" - Memory: {torch.cuda.memory_allocated()/(1024**3):.3f} GB")
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
+         if output_type == "latent":
+             image = latents
+
+         else:
+             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+             image = self.vae.decode(latents.to(device, dtype=self.vae.dtype), return_dict=False)[0]
+             image = self.image_processor.postprocess(image, output_type=output_type)
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (image,)
+
+         return FluxPipelineOutput(images=image)
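
Relative to the stock FluxPipeline, the class above encodes a negative (or unconditional) prompt and image, runs the transformer twice per denoising step, and blends the two predictions with standard classifier-free guidance: noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond). Because the two passes are issued sequentially rather than batched (see the "# todo combine these" comment), every step costs two full transformer forwards. A toy sketch of just the combination step, with shapes matching packed Flux latents for a 1024x1024 image (the tensors are random placeholders):

import torch

def cfg_combine(noise_pred_text: torch.Tensor, noise_pred_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: push the prediction away from the unconditional
    # branch and toward the text/image-conditioned branch.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Packed Flux latents: (batch, image_seq_len, channels) = (1, 64 * 64, 64) at 1024x1024.
noise_pred_text = torch.randn(1, 4096, 64)
noise_pred_uncond = torch.randn(1, 4096, 64)
print(cfg_combine(noise_pred_text, noise_pred_uncond, 7.0).shape)  # torch.Size([1, 4096, 64])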