jiuface commited on
Commit
35b4517
·
verified ·
1 Parent(s): 9f1c25f

Delete live_preview_helpers.py

Browse files
Files changed (1) hide show
  1. live_preview_helpers.py +0 -166
live_preview_helpers.py DELETED
@@ -1,166 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from diffusers import FluxPipeline, AutoencoderTiny, FlowMatchEulerDiscreteScheduler
4
- from typing import Any, Dict, List, Optional, Union
5
-
6
- # Helper functions
7
- def calculate_shift(
8
- image_seq_len,
9
- base_seq_len: int = 256,
10
- max_seq_len: int = 4096,
11
- base_shift: float = 0.5,
12
- max_shift: float = 1.16,
13
- ):
14
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
15
- b = base_shift - m * base_seq_len
16
- mu = image_seq_len * m + b
17
- return mu
18
-
19
- def retrieve_timesteps(
20
- scheduler,
21
- num_inference_steps: Optional[int] = None,
22
- device: Optional[Union[str, torch.device]] = None,
23
- timesteps: Optional[List[int]] = None,
24
- sigmas: Optional[List[float]] = None,
25
- **kwargs,
26
- ):
27
- if timesteps is not None and sigmas is not None:
28
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
29
- if timesteps is not None:
30
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
31
- timesteps = scheduler.timesteps
32
- num_inference_steps = len(timesteps)
33
- elif sigmas is not None:
34
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
35
- timesteps = scheduler.timesteps
36
- num_inference_steps = len(timesteps)
37
- else:
38
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
39
- timesteps = scheduler.timesteps
40
- return timesteps, num_inference_steps
41
-
42
- # FLUX pipeline function
43
- @torch.inference_mode()
44
- def flux_pipe_call_that_returns_an_iterable_of_images(
45
- self,
46
- prompt: Union[str, List[str]] = None,
47
- prompt_2: Optional[Union[str, List[str]]] = None,
48
- height: Optional[int] = None,
49
- width: Optional[int] = None,
50
- num_inference_steps: int = 28,
51
- timesteps: List[int] = None,
52
- guidance_scale: float = 3.5,
53
- num_images_per_prompt: Optional[int] = 1,
54
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
55
- latents: Optional[torch.FloatTensor] = None,
56
- prompt_embeds: Optional[torch.FloatTensor] = None,
57
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
58
- output_type: Optional[str] = "pil",
59
- return_dict: bool = True,
60
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
61
- max_sequence_length: int = 512,
62
- good_vae: Optional[Any] = None,
63
- ):
64
- height = height or self.default_sample_size * self.vae_scale_factor
65
- width = width or self.default_sample_size * self.vae_scale_factor
66
-
67
- # 1. Check inputs
68
- self.check_inputs(
69
- prompt,
70
- prompt_2,
71
- height,
72
- width,
73
- prompt_embeds=prompt_embeds,
74
- pooled_prompt_embeds=pooled_prompt_embeds,
75
- max_sequence_length=max_sequence_length,
76
- )
77
-
78
- self._guidance_scale = guidance_scale
79
- self._joint_attention_kwargs = joint_attention_kwargs
80
- self._interrupt = False
81
-
82
- # 2. Define call parameters
83
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
84
- device = self._execution_device
85
-
86
- # 3. Encode prompt
87
- lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
88
- prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
89
- prompt=prompt,
90
- prompt_2=prompt_2,
91
- prompt_embeds=prompt_embeds,
92
- pooled_prompt_embeds=pooled_prompt_embeds,
93
- device=device,
94
- num_images_per_prompt=num_images_per_prompt,
95
- max_sequence_length=max_sequence_length,
96
- lora_scale=lora_scale,
97
- )
98
- # 4. Prepare latent variables
99
- num_channels_latents = self.transformer.config.in_channels // 4
100
- latents, latent_image_ids = self.prepare_latents(
101
- batch_size * num_images_per_prompt,
102
- num_channels_latents,
103
- height,
104
- width,
105
- prompt_embeds.dtype,
106
- device,
107
- generator,
108
- latents,
109
- )
110
- # 5. Prepare timesteps
111
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
112
- image_seq_len = latents.shape[1]
113
- mu = calculate_shift(
114
- image_seq_len,
115
- self.scheduler.config.base_image_seq_len,
116
- self.scheduler.config.max_image_seq_len,
117
- self.scheduler.config.base_shift,
118
- self.scheduler.config.max_shift,
119
- )
120
- timesteps, num_inference_steps = retrieve_timesteps(
121
- self.scheduler,
122
- num_inference_steps,
123
- device,
124
- timesteps,
125
- sigmas,
126
- mu=mu,
127
- )
128
- self._num_timesteps = len(timesteps)
129
-
130
- # Handle guidance
131
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
132
-
133
- # 6. Denoising loop
134
- for i, t in enumerate(timesteps):
135
- if self.interrupt:
136
- continue
137
-
138
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
139
-
140
- noise_pred = self.transformer(
141
- hidden_states=latents,
142
- timestep=timestep / 1000,
143
- guidance=guidance,
144
- pooled_projections=pooled_prompt_embeds,
145
- encoder_hidden_states=prompt_embeds,
146
- txt_ids=text_ids,
147
- img_ids=latent_image_ids,
148
- joint_attention_kwargs=self.joint_attention_kwargs,
149
- return_dict=False,
150
- )[0]
151
- # Yield intermediate result
152
- latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
153
- latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
154
- image = self.vae.decode(latents_for_image, return_dict=False)[0]
155
- yield self.image_processor.postprocess(image, output_type=output_type)[0]
156
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
157
- torch.cuda.empty_cache()
158
-
159
-
160
- # Final image using good_vae
161
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
162
- latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
163
- image = good_vae.decode(latents, return_dict=False)[0]
164
- self.maybe_free_model_hooks()
165
- torch.cuda.empty_cache()
166
- yield self.image_processor.postprocess(image, output_type=output_type)[0]