eagleswim committed on
Commit df0e1bf · verified · 1 Parent(s): c870c60

Update custome_pipeline.py

Files changed (1)
  1. custome_pipeline.py +1 -168
custome_pipeline.py CHANGED
@@ -1,168 +1 @@
- import torch
- import numpy as np
- from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
- from typing import Any, Dict, List, Optional, Union
- from PIL import Image
-
- # Constants for shift calculation
- BASE_SEQ_LEN = 256
- MAX_SEQ_LEN = 4096
- BASE_SHIFT = 0.5
- MAX_SHIFT = 1.2
-
- # Helper functions
- def calculate_timestep_shift(image_seq_len: int) -> float:
-     """Calculates the timestep shift (mu) based on the image sequence length."""
-     m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
-     b = BASE_SHIFT - m * BASE_SEQ_LEN
-     mu = image_seq_len * m + b
-     return mu
-
- def prepare_timesteps(
-     scheduler: FlowMatchEulerDiscreteScheduler,
-     num_inference_steps: Optional[int] = None,
-     device: Optional[Union[str, torch.device]] = None,
-     timesteps: Optional[List[int]] = None,
-     sigmas: Optional[List[float]] = None,
-     mu: Optional[float] = None,
- ) -> (torch.Tensor, int):
-     """Prepares the timesteps for the diffusion process."""
-     if timesteps is not None and sigmas is not None:
-         raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
-
-     if timesteps is not None:
-         scheduler.set_timesteps(timesteps=timesteps, device=device)
-     elif sigmas is not None:
-         scheduler.set_timesteps(sigmas=sigmas, device=device)
-     else:
-         scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
-
-     timesteps = scheduler.timesteps
-     num_inference_steps = len(timesteps)
-     return timesteps, num_inference_steps
-
- # FLUX pipeline function
- class FluxWithCFGPipeline(FluxPipeline):
-     """
-     Extends the FluxPipeline to yield intermediate images during the denoising process
-     with progressively increasing resolution for faster generation.
-     """
-     @torch.inference_mode()
-     def generate_images(
-         self,
-         prompt: Union[str, List[str]] = None,
-         prompt_2: Optional[Union[str, List[str]]] = None,
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         num_inference_steps: int = 4,
-         timesteps: List[int] = None,
-         guidance_scale: float = 3.5,
-         num_images_per_prompt: Optional[int] = 1,
-         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-         latents: Optional[torch.FloatTensor] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-         max_sequence_length: int = 300,
-     ):
-         """Generates images and yields intermediate results during the denoising process."""
-         height = height or self.default_sample_size * self.vae_scale_factor
-         width = width or self.default_sample_size * self.vae_scale_factor
-
-         # 1. Check inputs
-         self.check_inputs(
-             prompt,
-             prompt_2,
-             height,
-             width,
-             prompt_embeds=prompt_embeds,
-             pooled_prompt_embeds=pooled_prompt_embeds,
-             max_sequence_length=max_sequence_length,
-         )
-
-         self._guidance_scale = guidance_scale
-         self._joint_attention_kwargs = joint_attention_kwargs
-         self._interrupt = False
-
-         # 2. Define call parameters
-         batch_size = 1 if isinstance(prompt, str) else len(prompt)
-         device = self._execution_device
-
-         # 3. Encode prompt
-         lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
-         prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
-             prompt=prompt,
-             prompt_2=prompt_2,
-             prompt_embeds=prompt_embeds,
-             pooled_prompt_embeds=pooled_prompt_embeds,
-             device=device,
-             num_images_per_prompt=num_images_per_prompt,
-             max_sequence_length=max_sequence_length,
-             lora_scale=lora_scale,
-         )
-         # 4. Prepare latent variables
-         num_channels_latents = self.transformer.config.in_channels // 4
-         latents, latent_image_ids = self.prepare_latents(
-             batch_size * num_images_per_prompt,
-             num_channels_latents,
-             height,
-             width,
-             prompt_embeds.dtype,
-             device,
-             generator,
-             latents,
-         )
-         # 5. Prepare timesteps
-         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-         image_seq_len = latents.shape[1]
-         mu = calculate_timestep_shift(image_seq_len)
-         timesteps, num_inference_steps = prepare_timesteps(
-             self.scheduler,
-             num_inference_steps,
-             device,
-             timesteps,
-             sigmas,
-             mu=mu,
-         )
-         self._num_timesteps = len(timesteps)
-
-         # Handle guidance
-         guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
-
-         # 6. Denoising loop
-         for i, t in enumerate(timesteps):
-             if self.interrupt:
-                 continue
-
-             timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-             noise_pred = self.transformer(
-                 hidden_states=latents,
-                 timestep=timestep / 1000,
-                 guidance=guidance,
-                 pooled_projections=pooled_prompt_embeds,
-                 encoder_hidden_states=prompt_embeds,
-                 txt_ids=text_ids,
-                 img_ids=latent_image_ids,
-                 joint_attention_kwargs=self.joint_attention_kwargs,
-                 return_dict=False,
-             )[0]
-
-             # Yield intermediate result
-             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-             torch.cuda.empty_cache()
-
-         # Final image
-         return self._decode_latents_to_image(latents, height, width, output_type)
-         self.maybe_free_model_hooks()
-         torch.cuda.empty_cache()
-
-     def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
-         """Decodes the given latents into an image."""
-         vae = vae or self.vae
-         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-         latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
-         image = vae.decode(latents, return_dict=False)[0]
-         return self.image_processor.postprocess(image, output_type=output_type)[0]
 
+
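
For context, below is a minimal usage sketch of the FluxWithCFGPipeline defined in the pre-commit version of custome_pipeline.py. The checkpoint id, prompt, device, and output path are illustrative assumptions and are not part of this repository; from_pretrained and .to() are the standard diffusers loading calls inherited from FluxPipeline.

# Minimal usage sketch (assumptions: checkpoint id, prompt, and CUDA device
# are illustrative; they are not taken from this commit).
import torch
from custome_pipeline import FluxWithCFGPipeline  # pre-commit version of this file

pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # assumed FLUX checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

# generate_images returns the final decoded PIL image;
# num_inference_steps=4 and max_sequence_length=300 match the method defaults.
image = pipe.generate_images(
    prompt="a watercolor fox in a snowy forest",
    height=1024,
    width=1024,
    num_inference_steps=4,
    guidance_scale=3.5,
)
image.save("fox.png")

With the constants in the removed file, calculate_timestep_shift interpolates the shift linearly between mu = 0.5 at an image sequence length of 256 tokens and mu = 1.2 at 4096 tokens, and prepare_timesteps forwards that mu to FlowMatchEulerDiscreteScheduler.set_timesteps.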