mpatel57 committed (verified)
Commit 9366016 · Parent(s): a030ba0

Update fluxcombined.py

Files changed (1):
  fluxcombined.py +5 -283
fluxcombined.py CHANGED
@@ -874,7 +874,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         # initialize the random noise for denoising
         latents = random_latents.clone().detach()
 
-        self.vae = self.vae.to(torch.float32)
+        # self.vae = self.vae.to(torch.float32)
 
         # 9. Denoising loop
         self.transformer.eval()
@@ -959,7 +959,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+            image = self.vae.decode(latents, return_dict=False)[0]
             image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
@@ -973,7 +973,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     def get_diff_image(self, latents):
         latents = self._unpack_latents(latents, 1024, 1024, self.vae_scale_factor)
         latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-        image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type="pt")
         return image
 
@@ -983,7 +983,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         img = img.resize((512, 512))
         return custom_image_processor(img).unsqueeze(0).to(device)
 
-
     @torch.no_grad()
     def edit(
         self,
@@ -1019,283 +1018,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         mask_image=None,
         source_steps=1,
     ):
-
-        height = height or self.default_sample_size * self.vae_scale_factor
-        width = width or self.default_sample_size * self.vae_scale_factor
-
-        # 1. Check inputs. Raise error if not correct
-        self.check_inputs(
-            prompt,
-            prompt_2,
-            height,
-            width,
-            # negative_prompt=negative_prompt,
-            # negative_prompt_2=negative_prompt_2,
-            prompt_embeds=prompt_embeds,
-            # negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
-            max_sequence_length=max_sequence_length,
-        )
-
-        self._guidance_scale = guidance_scale
-        self._joint_attention_kwargs = joint_attention_kwargs
-        self._interrupt = False
-
-        # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        device = self._execution_device
-
-        lora_scale = (
-            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
-        )
-        do_true_cfg = true_cfg > 1 and negative_prompt is not None
-        (
-            prompt_embeds,
-            pooled_prompt_embeds,
-            text_ids,
-            negative_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = self.encode_prompt_edit(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
-            do_true_cfg=do_true_cfg,
-        )
-        # text_ids = text_ids.repeat(batch_size, 1, 1)
-
-        if do_true_cfg:
-            # Concatenate embeddings
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-
-        # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
-        random_latents, latent_image_ids = self.prepare_latents(
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            latents,
-        )
-        # latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1)
-
-        # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = random_latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            timesteps,
-            sigmas,
-            mu=mu,
-        )
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
-        # 4. Preprocess image
-        image = self.image_processor.preprocess(input_image)
-        image = image.to(device=device, dtype=self.transformer.dtype)
-        latents = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
-
-
-        # Convert PIL image to tensor
-        if mask_image:
-            from torchvision import transforms as TF
-
-            h, w = latents.shape[2], latents.shape[3]
-            mask = TF.ToTensor()(mask_image).to(device=device, dtype=self.transformer.dtype)
-            mask = TF.Resize((h, w), interpolation=TF.InterpolationMode.NEAREST)(mask)
-            mask = (mask > 0.5).float()
-            mask = mask.squeeze(0)#.squeeze(0) # Remove the added dimensions
-        else:
-            mask = torch.ones_like(latents).to(device=device)
-
-        print(mask.shape, latents.shape)
-
-        bool_mask = mask.unsqueeze(0).unsqueeze(0).expand_as(latents)
-        mask=(1-bool_mask*1.0).to(latents.dtype)
-
-        masked_latents = (latents * mask).clone().detach() # apply the mask and get gt_latents
-        masked_latents = self._pack_latents(masked_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
-
-        source_latents = (latents).clone().detach() # apply the mask and get gt_latents
-        source_latents = self._pack_latents(source_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
-
-        mask = self._pack_latents(mask, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
-
-        # initialize the random noise for denoising
-        latents = random_latents.clone().detach()
-
-        self.vae = self.vae.to(torch.float32)
-
-        # 9. Denoising loop
-        self.transformer.eval()
-        self.vae.eval()
-
-        # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
-
-                # handle guidance
-                if self.transformer.config.guidance_embeds:
-                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-                    guidance = guidance.expand(latent_model_input.shape[0])
-                else:
-                    guidance = None
-
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
-                    # noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
-                    noise_pred = noise_pred + (1-mask)*(noise_pred - neg_noise_pred) * true_cfg
-                # else:
-                #     neg_noise_pred, noise_pred = noise_pred.chunk(2)
-
-                # perform CG
-                if i < max_steps:
-                    opt_latents = latents.detach().clone()
-                    with torch.enable_grad():
-                        opt_latents = opt_latents.detach().requires_grad_()
-                        opt_latents = torch.autograd.Variable(opt_latents, requires_grad=True)
-                        # optimizer = torch.optim.Adam([opt_latents], lr=learning_rate)
-
-                        for _ in range(optimization_steps):
-                            latents_p = self.scheduler.step_final(noise_pred, t, opt_latents, return_dict=False)[0]
-                            if i < source_steps:
-                                loss = (1000*torch.nn.functional.mse_loss(latents_p, source_latents, reduction='none')).mean()
-                            else:
-                                loss = (1000*torch.nn.functional.mse_loss(latents_p, masked_latents, reduction='none')*mask).mean()
-
-                            grad = torch.autograd.grad(loss, opt_latents)[0]
-                            # grad = torch.clamp(grad, -0.5, 0.5)
-                            opt_latents = opt_latents - learning_rate * grad
-
-                    latents = opt_latents.detach().clone()
-
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-
-                if XLA_AVAILABLE:
-                    xm.mark_step()
-
-        if output_type == "latent":
-            image = latents
-
-        else:
-            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
-            image = self.image_processor.postprocess(image, output_type=output_type)
-
-        # Offload all models
-        self.maybe_free_model_hooks()
-
-        if not return_dict:
-            return (image,)
-
-        return FluxPipelineOutput(images=image)
-
-    @torch.no_grad()
-    def edit2(
-        self,
-        prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt: Union[str, List[str]] = None, #
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        true_cfg: float = 1.0, #
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_inference_steps: int = 28,
-        timesteps: List[int] = None,
-        guidance_scale: float = 3.5,
-        num_images_per_prompt: Optional[int] = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 512,
-        optimization_steps: int = 3,
-        learning_rate: float = 0.8,
-        max_steps: int = 5,
-        input_image = None,
-        save_masked_image = False,
-        output_path="",
-        mask_image=None,
-        source_steps=1,
-    ):
         r"""
         Function invoked when calling the pipeline for generation.
 
@@ -1498,7 +1220,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         # initialize the random noise for denoising
         latents = random_latents.clone().detach()
 
-        self.vae = self.vae.to(torch.float32)
+        # self.vae = self.vae.to(torch.float32)
 
         # 9. Denoising loop
         self.transformer.eval()
@@ -1594,7 +1316,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+            image = self.vae.decode(latents, return_dict=False)[0]
             image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
 
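For context, every edited line sits on the latent-decoding path of the pipeline. Below is a minimal sketch of that path as it reads after this diff, assuming a loaded FluxPipeline instance; the decode_latents helper and the pipe argument are illustrative names, not part of fluxcombined.py.

import torch

@torch.no_grad()
def decode_latents(pipe, latents, height, width, output_type="pil"):
    # Sketch only: mirrors the decode block in the hunks above. Latents are
    # unpacked, rescaled with the VAE's scaling/shift factors, and decoded in
    # the VAE's own dtype; the explicit .to(torch.float32) casts are no longer applied.
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    return pipe.image_processor.postprocess(image, output_type=output_type)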