Update fluxcombined.py
fluxcombined.py  CHANGED  (+5 −283)
@@ -874,7 +874,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         # initialize the random noise for denoising
         latents = random_latents.clone().detach()

-        self.vae = self.vae.to(torch.float32)
+        # self.vae = self.vae.to(torch.float32)

         # 9. Denoising loop
         self.transformer.eval()
@@ -959,7 +959,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+            image = self.vae.decode(latents, return_dict=False)[0]
             image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload all models
@@ -973,7 +973,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     def get_diff_image(self, latents):
         latents = self._unpack_latents(latents, 1024, 1024, self.vae_scale_factor)
         latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-        image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type="pt")
         return image

@@ -983,7 +983,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         img = img.resize((512, 512))
         return custom_image_processor(img).unsqueeze(0).to(device)

-
     @torch.no_grad()
     def edit(
         self,
@@ -1019,283 +1018,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         mask_image=None,
         source_steps=1,
     ):
-
-        height = height or self.default_sample_size * self.vae_scale_factor
-        width = width or self.default_sample_size * self.vae_scale_factor
-
-        # 1. Check inputs. Raise error if not correct
-        self.check_inputs(
-            prompt,
-            prompt_2,
-            height,
-            width,
-            # negative_prompt=negative_prompt,
-            # negative_prompt_2=negative_prompt_2,
-            prompt_embeds=prompt_embeds,
-            # negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
-            max_sequence_length=max_sequence_length,
-        )
-
-        self._guidance_scale = guidance_scale
-        self._joint_attention_kwargs = joint_attention_kwargs
-        self._interrupt = False
-
-        # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        device = self._execution_device
-
-        lora_scale = (
-            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
-        )
-        do_true_cfg = true_cfg > 1 and negative_prompt is not None
-        (
-            prompt_embeds,
-            pooled_prompt_embeds,
-            text_ids,
-            negative_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = self.encode_prompt_edit(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
-            do_true_cfg=do_true_cfg,
-        )
-        # text_ids = text_ids.repeat(batch_size, 1, 1)
-
-        if do_true_cfg:
-            # Concatenate embeddings
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-
-        # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
-        random_latents, latent_image_ids = self.prepare_latents(
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            latents,
-        )
-        # latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1)
-
-        # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = random_latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            timesteps,
-            sigmas,
-            mu=mu,
-        )
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
-        # 4. Preprocess image
-        image = self.image_processor.preprocess(input_image)
-        image = image.to(device=device, dtype=self.transformer.dtype)
-        latents = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
-
-
-        # Convert PIL image to tensor
-        if mask_image:
-            from torchvision import transforms as TF
-
-            h, w = latents.shape[2], latents.shape[3]
-            mask = TF.ToTensor()(mask_image).to(device=device, dtype=self.transformer.dtype)
-            mask = TF.Resize((h, w), interpolation=TF.InterpolationMode.NEAREST)(mask)
-            mask = (mask > 0.5).float()
-            mask = mask.squeeze(0)#.squeeze(0) # Remove the added dimensions
-        else:
-            mask = torch.ones_like(latents).to(device=device)
-
-        print(mask.shape, latents.shape)
-
-        bool_mask = mask.unsqueeze(0).unsqueeze(0).expand_as(latents)
-        mask=(1-bool_mask*1.0).to(latents.dtype)
-
-        masked_latents = (latents * mask).clone().detach() # apply the mask and get gt_latents
-        masked_latents = self._pack_latents(masked_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
-
-        source_latents = (latents).clone().detach() # apply the mask and get gt_latents
-        source_latents = self._pack_latents(source_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
-
-        mask = self._pack_latents(mask, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
-
-        # initialize the random noise for denoising
-        latents = random_latents.clone().detach()
-
-        self.vae = self.vae.to(torch.float32)
-
-        # 9. Denoising loop
-        self.transformer.eval()
-        self.vae.eval()
-
-        # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
-
-                # handle guidance
-                if self.transformer.config.guidance_embeds:
-                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-                    guidance = guidance.expand(latent_model_input.shape[0])
-                else:
-                    guidance = None
-
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
-                    # noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
-                    noise_pred = noise_pred + (1-mask)*(noise_pred - neg_noise_pred) * true_cfg
-                # else:
-                #     neg_noise_pred, noise_pred = noise_pred.chunk(2)
-
-                # perform CG
-                if i < max_steps:
-                    opt_latents = latents.detach().clone()
-                    with torch.enable_grad():
-                        opt_latents = opt_latents.detach().requires_grad_()
-                        opt_latents = torch.autograd.Variable(opt_latents, requires_grad=True)
-                        # optimizer = torch.optim.Adam([opt_latents], lr=learning_rate)
-
-                        for _ in range(optimization_steps):
-                            latents_p = self.scheduler.step_final(noise_pred, t, opt_latents, return_dict=False)[0]
-                            if i < source_steps:
-                                loss = (1000*torch.nn.functional.mse_loss(latents_p, source_latents, reduction='none')).mean()
-                            else:
-                                loss = (1000*torch.nn.functional.mse_loss(latents_p, masked_latents, reduction='none')*mask).mean()
-
-                            grad = torch.autograd.grad(loss, opt_latents)[0]
-                            # grad = torch.clamp(grad, -0.5, 0.5)
-                            opt_latents = opt_latents - learning_rate * grad
-
-                    latents = opt_latents.detach().clone()
-
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-
-                if XLA_AVAILABLE:
-                    xm.mark_step()
-
-        if output_type == "latent":
-            image = latents
-
-        else:
-            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
-            image = self.image_processor.postprocess(image, output_type=output_type)
-
-        # Offload all models
-        self.maybe_free_model_hooks()
-
-        if not return_dict:
-            return (image,)
-
-        return FluxPipelineOutput(images=image)
-
-    @torch.no_grad()
-    def edit2(
-        self,
-        prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt: Union[str, List[str]] = None, #
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        true_cfg: float = 1.0, #
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_inference_steps: int = 28,
-        timesteps: List[int] = None,
-        guidance_scale: float = 3.5,
-        num_images_per_prompt: Optional[int] = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 512,
-        optimization_steps: int = 3,
-        learning_rate: float = 0.8,
-        max_steps: int = 5,
-        input_image = None,
-        save_masked_image = False,
-        output_path="",
-        mask_image=None,
-        source_steps=1,
-    ):
         r"""
         Function invoked when calling the pipeline for generation.

@@ -1498,7 +1220,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         # initialize the random noise for denoising
         latents = random_latents.clone().detach()

-        self.vae = self.vae.to(torch.float32)
+        # self.vae = self.vae.to(torch.float32)

         # 9. Denoising loop
         self.transformer.eval()
@@ -1594,7 +1316,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+            image = self.vae.decode(latents, return_dict=False)[0]
             image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload all models
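
The decode path that the "+" lines converge on follows the usual diffusers sequence: unpack the packed latents, undo the VAE scaling and shift, decode with the VAE in its load-time dtype (no float32 cast, which is what the removed lines forced), then post-process. The sketch below is a minimal, illustrative reproduction of that sequence using the public diffusers API only; the checkpoint id, dtype, device, and latent shape are assumptions for the example, not part of this commit.

# Minimal sketch of the updated decode path (illustrative only; model id,
# dtype, device, and latent shape below are assumptions).
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")
image_processor = VaeImageProcessor(vae_scale_factor=8)

# Stand-in for latents after _unpack_latents: (batch, latent_channels, H/8, W/8).
latents = torch.randn(
    1, vae.config.latent_channels, 128, 128, dtype=torch.bfloat16, device="cuda"
)

# Undo the latent normalization, then decode in the VAE's native dtype,
# matching the "+" lines above (no .to(torch.float32) round trip).
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
with torch.no_grad():
    image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="pil")[0]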