multimodalart HF Staff committed on
Commit
0f0144b
·
verified ·
1 Parent(s): b9f2bbb

Update clip_slider_pipeline.py

Browse files
Files changed (1) hide show
  1. clip_slider_pipeline.py +120 -39
clip_slider_pipeline.py CHANGED
@@ -4,6 +4,66 @@ import random
4
  from tqdm import tqdm
5
  from constants import SUBJECTS, MEDIUMS
6
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class CLIPSlider:
9
  def __init__(
@@ -49,9 +109,9 @@ class CLIPSlider:
49
  pos_prompt = f"a {medium} of a {target_word} {subject}"
50
  neg_prompt = f"a {medium} of a {opposite} {subject}"
51
  pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
52
- max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
53
  neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
54
- max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
55
  pos = self.pipe.text_encoder(pos_toks).pooler_output
56
  neg = self.pipe.text_encoder(neg_toks).pooler_output
57
  positives.append(pos)
@@ -81,7 +141,7 @@ class CLIPSlider:
81
 
82
  with torch.no_grad():
83
  toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
84
- max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
85
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
86
 
87
  if self.avg_diff_2nd and normalize_scales:
@@ -163,18 +223,18 @@ class CLIPSliderXL(CLIPSlider):
163
  neg_prompt = f"a {medium} of a {opposite} {subject}"
164
 
165
  pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
166
- max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
167
  neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
168
- max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
169
  pos = self.pipe.text_encoder(pos_toks).pooler_output
170
  neg = self.pipe.text_encoder(neg_toks).pooler_output
171
  positives.append(pos)
172
  negatives.append(neg)
173
 
174
  pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
175
- max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
176
  neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
177
- max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
178
  pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
179
  neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
180
  positives2.append(pos2)
@@ -207,7 +267,7 @@ class CLIPSliderXL(CLIPSlider):
207
  text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
208
  tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
209
  with torch.no_grad():
210
- # toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.cuda()
211
  # prompt_embeds = pipe.text_encoder(toks).last_hidden_state
212
 
213
  prompt_embeds_list = []
@@ -300,18 +360,18 @@ class CLIPSliderXL_inv(CLIPSlider):
300
  neg_prompt = f"a {medium} of a {opposite} {subject}"
301
 
302
  pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
303
- max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
304
  neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
305
- max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
306
  pos = self.pipe.text_encoder(pos_toks).pooler_output
307
  neg = self.pipe.text_encoder(neg_toks).pooler_output
308
  positives.append(pos)
309
  negatives.append(neg)
310
 
311
  pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
312
- max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
313
  neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
314
- max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
315
  pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
316
  neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
317
  positives2.append(pos2)
@@ -377,14 +437,14 @@ class CLIPSliderFlux(CLIPSlider):
377
  truncation=True,
378
  return_overflowing_tokens=False,
379
  return_length=False,
380
- return_tensors="pt",).input_ids.cuda()
381
  neg_toks = self.pipe.tokenizer(neg_prompt,
382
  padding="max_length",
383
  max_length=self.pipe.tokenizer_max_length,
384
  truncation=True,
385
  return_overflowing_tokens=False,
386
  return_length=False,
387
- return_tensors="pt",).input_ids.cuda()
388
  pos = self.pipe.text_encoder(pos_toks).pooler_output
389
  neg = self.pipe.text_encoder(neg_toks).pooler_output
390
  positives.append(pos)
@@ -400,17 +460,22 @@ class CLIPSliderFlux(CLIPSlider):
400
 
401
  def generate(self,
402
  prompt = "a photo of a house",
403
- scale = 2,
404
- scale_2nd = 2,
405
  seed = 15,
406
  normalize_scales = False,
407
  avg_diff = None,
408
- avg_diff_2nd = None,
 
 
409
  **pipeline_kwargs
410
  ):
411
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
412
  # if pooler token only [-4,4] work well
413
 
 
 
 
 
414
  with torch.no_grad():
415
  text_inputs = self.pipe.tokenizer(
416
  prompt,
@@ -423,15 +488,11 @@ class CLIPSliderFlux(CLIPSlider):
423
  )
424
 
425
  text_input_ids = text_inputs.input_ids
426
- prompt_embeds = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
427
-
428
- # Use pooled output of CLIPTextModel
429
- prompt_embeds = prompt_embeds.pooler_output
430
- pooled_prompt_embeds = prompt_embeds.to(dtype=self.pipe.text_encoder.dtype, device=self.device)
431
-
432
- # Use pooled output of CLIPTextModel
433
-
434
- text_inputs = self.pipe.tokenizer_2(
435
  prompt,
436
  padding="max_length",
437
  max_length=512,
@@ -440,21 +501,40 @@ class CLIPSliderFlux(CLIPSlider):
440
  return_overflowing_tokens=False,
441
  return_tensors="pt",
442
  )
443
- toks = text_inputs.input_ids
444
- prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
445
- dtype = self.pipe.text_encoder_2.dtype
446
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
447
- if avg_diff_2nd is not None and normalize_scales:
448
- denominator = abs(scale) + abs(scale_2nd)
449
- scale = scale / denominator
450
- scale_2nd = scale_2nd / denominator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
- pooled_prompt_embeds = pooled_prompt_embeds + avg_diff * scale
453
  if avg_diff_2nd is not None:
454
- pooled_prompt_embeds += avg_diff_2nd * scale_2nd
 
455
 
456
  torch.manual_seed(seed)
457
- images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
 
458
  **pipeline_kwargs).images
459
 
460
  return images[0]
@@ -483,6 +563,7 @@ class CLIPSliderFlux(CLIPSlider):
483
  canvas.paste(im, (640 * i, 0))
484
 
485
  return canvas
 
486
  class T5SliderFlux(CLIPSlider):
487
 
488
  def find_latent_direction(self,
@@ -509,14 +590,14 @@ class T5SliderFlux(CLIPSlider):
509
  truncation=True,
510
  return_length=False,
511
  return_overflowing_tokens=False,
512
- max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
513
  neg_toks = self.pipe.tokenizer_2(neg_prompt,
514
  return_tensors="pt",
515
  padding="max_length",
516
  truncation=True,
517
  return_length=False,
518
  return_overflowing_tokens=False,
519
- max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
520
  pos = self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0]
521
  neg = self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0]
522
  positives.append(pos)
 
4
  from tqdm import tqdm
5
  from constants import SUBJECTS, MEDIUMS
6
  from PIL import Image
7
+ import math # For acos, sin
8
+
9
+ # Slerp (Spherical Linear Interpolation) function
10
def slerp(v0, v1, t, DOT_THRESHOLD=0.9995):
    """
    Spherical linear interpolation between ``v0`` and ``v1`` along the last dim.

    The direction is interpolated on the unit sphere while the magnitude is
    interpolated linearly between ``|v0|`` and ``|v1|``.

    Args:
        v0: Start tensor; vectors live on the last dimension.
        v1: End tensor, broadcastable against ``v0``.
        t: Interpolation factor. Either a Python scalar or a tensor that
            broadcasts against the ``keepdim=True`` norms of ``v0``/``v1``
            (e.g. one factor per row).
        DOT_THRESHOLD: Accepted for backward compatibility only. It does not
            influence the result; the collinear fallback is triggered by
            ``|sin(omega)| < 1e-5`` instead.

    Returns:
        The interpolated tensor, cast back to ``v0``'s original dtype.
    """
    out_dtype = v0.dtype
    # Work in at least float32: in float16 the 1e-8 stabilisers below round
    # to zero, which turns a zero-norm input into NaN.
    work_dtype = torch.promote_types(out_dtype, torch.float32)
    v0 = v0.to(work_dtype)
    v1 = v1.to(work_dtype)
    t = torch.as_tensor(t, device=v0.device, dtype=work_dtype)

    eps = 1e-8

    # Magnitudes are needed both for the cosine and for the final rescale.
    mag_v0 = torch.norm(v0, dim=-1, keepdim=True)
    mag_v1 = torch.norm(v1, dim=-1, keepdim=True)

    # Cosine of the angle between the vectors; clamp so that rounding error
    # cannot push it outside acos' domain.
    dot = torch.sum(v0 * v1 / (mag_v0 * mag_v1 + eps), dim=-1, keepdim=True)
    dot = torch.clamp(dot, -1.0, 1.0)
    omega = torch.acos(dot)  # Angle between vectors
    sin_omega = torch.sin(omega)

    interpolated_mag = (1 - t) * mag_v0 + t * mag_v1

    # Normalize v0 and v1 for pure Slerp on direction
    v0_norm = v0 / (mag_v0 + eps)
    v1_norm = v1 / (mag_v1 + eps)

    # When sin(omega) is tiny the vectors are nearly collinear and the Slerp
    # weights are ill-conditioned, so fall back to plain LERP weights.
    # torch.where (rather than masked assignment) keeps this valid when ``t``
    # is a per-row tensor and only some rows need the fallback.
    use_lerp_fallback = sin_omega.abs() < 1e-5
    s0 = torch.where(use_lerp_fallback, 1.0 - t, torch.sin((1 - t) * omega) / (sin_omega + eps))
    s1 = torch.where(use_lerp_fallback, t, torch.sin(t * omega) / (sin_omega + eps))

    result_norm = s0 * v0_norm + s1 * v1_norm
    result = result_norm * interpolated_mag  # Re-apply interpolated magnitude

    return result.to(out_dtype)
67
 
68
  class CLIPSlider:
69
  def __init__(
 
109
  pos_prompt = f"a {medium} of a {target_word} {subject}"
110
  neg_prompt = f"a {medium} of a {opposite} {subject}"
111
  pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
112
+ max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
113
  neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
114
+ max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
115
  pos = self.pipe.text_encoder(pos_toks).pooler_output
116
  neg = self.pipe.text_encoder(neg_toks).pooler_output
117
  positives.append(pos)
 
141
 
142
  with torch.no_grad():
143
  toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
144
+ max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
145
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
146
 
147
  if self.avg_diff_2nd and normalize_scales:
 
223
  neg_prompt = f"a {medium} of a {opposite} {subject}"
224
 
225
  pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
226
+ max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
227
  neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
228
+ max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
229
  pos = self.pipe.text_encoder(pos_toks).pooler_output
230
  neg = self.pipe.text_encoder(neg_toks).pooler_output
231
  positives.append(pos)
232
  negatives.append(neg)
233
 
234
  pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
235
+ max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
236
  neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
237
+ max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
238
  pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
239
  neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
240
  positives2.append(pos2)
 
267
  text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
268
  tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
269
  with torch.no_grad():
270
+ # toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.to(self.device)
271
  # prompt_embeds = pipe.text_encoder(toks).last_hidden_state
272
 
273
  prompt_embeds_list = []
 
360
  neg_prompt = f"a {medium} of a {opposite} {subject}"
361
 
362
  pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
363
+ max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
364
  neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
365
+ max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
366
  pos = self.pipe.text_encoder(pos_toks).pooler_output
367
  neg = self.pipe.text_encoder(neg_toks).pooler_output
368
  positives.append(pos)
369
  negatives.append(neg)
370
 
371
  pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
372
+ max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
373
  neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
374
+ max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
375
  pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
376
  neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
377
  positives2.append(pos2)
 
437
  truncation=True,
438
  return_overflowing_tokens=False,
439
  return_length=False,
440
+ return_tensors="pt",).input_ids.to(self.device)
441
  neg_toks = self.pipe.tokenizer(neg_prompt,
442
  padding="max_length",
443
  max_length=self.pipe.tokenizer_max_length,
444
  truncation=True,
445
  return_overflowing_tokens=False,
446
  return_length=False,
447
+ return_tensors="pt",).input_ids.to(self.device)
448
  pos = self.pipe.text_encoder(pos_toks).pooler_output
449
  neg = self.pipe.text_encoder(neg_toks).pooler_output
450
  positives.append(pos)
 
460
 
461
  def generate(self,
462
  prompt = "a photo of a house",
463
+ scale = 2.0,
 
464
  seed = 15,
465
  normalize_scales = False,
466
  avg_diff = None,
467
+ avg_diff_2nd = None,
468
+ use_slerp: bool = False,
469
+ max_strength_for_slerp_endpoint: float = 0.0,
470
  **pipeline_kwargs
471
  ):
472
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
473
  # if pooler token only [-4,4] work well
474
 
475
+ # Remove slider-specific kwargs before passing to the pipeline
476
+ pipeline_kwargs.pop('use_slerp', None)
477
+ pipeline_kwargs.pop('max_strength_for_slerp_endpoint', None)
478
+
479
  with torch.no_grad():
480
  text_inputs = self.pipe.tokenizer(
481
  prompt,
 
488
  )
489
 
490
  text_input_ids = text_inputs.input_ids
491
+ prompt_embeds_out = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
492
+ original_pooled_prompt_embeds = prompt_embeds_out.pooler_output.to(dtype=self.pipe.text_encoder.dtype, device=self.device)
493
+
494
+ # For the second text encoder (T5-like for FLUX)
495
+ text_inputs_2 = self.pipe.tokenizer_2(
 
 
 
 
496
  prompt,
497
  padding="max_length",
498
  max_length=512,
 
501
  return_overflowing_tokens=False,
502
  return_tensors="pt",
503
  )
504
+ toks_2 = text_inputs_2.input_ids
505
+ # This is the non-pooled, sequence output for the second encoder
506
+ prompt_embeds_seq_2 = self.pipe.text_encoder_2(toks_2.to(self.device), output_hidden_states=False)[0]
507
+ prompt_embeds_seq_2 = prompt_embeds_seq_2.to(dtype=self.pipe.text_encoder_2.dtype, device=self.device)
508
+
509
+ modified_pooled_embeds = original_pooled_prompt_embeds.clone()
510
+
511
+ if avg_diff is not None:
512
+ if use_slerp and max_strength_for_slerp_endpoint != 0.0:
513
+ # Slerp logic
514
+ slerp_t_val = 0.0
515
+ if max_strength_for_slerp_endpoint != 0:
516
+ slerp_t_val = abs(scale) / max_strength_for_slerp_endpoint
517
+ slerp_t_val = min(slerp_t_val, 1.0)
518
+
519
+ if scale == 0:
520
+ pass
521
+ else:
522
+ v0 = original_pooled_prompt_embeds.float()
523
+ if scale > 0:
524
+ v_end_target = original_pooled_prompt_embeds + max_strength_for_slerp_endpoint * avg_diff
525
+ else:
526
+ v_end_target = original_pooled_prompt_embeds - max_strength_for_slerp_endpoint * avg_diff
527
+ modified_pooled_embeds = slerp(v0, v_end_target.float(), slerp_t_val).to(original_pooled_prompt_embeds.dtype)
528
+ else:
529
+ modified_pooled_embeds = modified_pooled_embeds + avg_diff * scale
530
 
 
531
  if avg_diff_2nd is not None:
532
+ scale_2nd_val = pipeline_kwargs.get("scale_2nd", 0.0)
533
+ modified_pooled_embeds += avg_diff_2nd * scale_2nd_val
534
 
535
  torch.manual_seed(seed)
536
+ images = self.pipe(prompt_embeds=prompt_embeds_seq_2,
537
+ pooled_prompt_embeds=modified_pooled_embeds,
538
  **pipeline_kwargs).images
539
 
540
  return images[0]
 
563
  canvas.paste(im, (640 * i, 0))
564
 
565
  return canvas
566
+
567
  class T5SliderFlux(CLIPSlider):
568
 
569
  def find_latent_direction(self,
 
590
  truncation=True,
591
  return_length=False,
592
  return_overflowing_tokens=False,
593
+ max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
594
  neg_toks = self.pipe.tokenizer_2(neg_prompt,
595
  return_tensors="pt",
596
  padding="max_length",
597
  truncation=True,
598
  return_length=False,
599
  return_overflowing_tokens=False,
600
+ max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
601
  pos = self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0]
602
  neg = self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0]
603
  positives.append(pos)