bardofcodes committed
Commit ae73afb · verified · 1 Parent(s): 896ee3a

Update pipeline.py

Files changed (1)
  1. pipeline.py +70 -32
pipeline.py CHANGED
@@ -18,12 +18,7 @@ import einops
  import PIL.Image
  import numpy as np
  import torch as th
- import torch.nn as nn
- from torchvision import transforms
 
- from diffusers import ModelMixin
- from transformers import AutoModel, AutoConfig, SiglipVisionConfig, Dinov2Config, Dinov2Model
- from transformers import SiglipVisionModel
  from diffusers import DiffusionPipeline
  from diffusers.image_processor import VaeImageProcessor
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -31,8 +26,6 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
  from diffusers.utils.torch_utils import randn_tensor
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- # REf: https://github.com/tatp22/multidim-positional-encoding/tree/master
  from analogy_encoder import AnalogyEncoder
  from analogy_projector import AnalogyProjector
  from analogy_input_processor import AnalogyInputProcessor
@@ -259,8 +252,8 @@ class PatternAnalogyTrifuser(DiffusionPipeline):
  The call function to the pipeline for generation.
 
  Args:
- image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
- The image prompt or prompts to guide the image generation.
+ analogy_prompt (`List[Tuple[PIL.Image.Image]]`):
+ The analogy sequence A, A*, B that prompts the model to generate B*, the analogical pattern satisfying A:A*::B:B*.
  height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
  The height in pixels of the generated image.
  width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
@@ -301,32 +294,77 @@ class PatternAnalogyTrifuser(DiffusionPipeline):
  Examples:
 
  ```py
- >>> from diffusers import VersatileDiffusionImageVariationPipeline
- >>> import torch
- >>> import requests
- >>> from io import BytesIO
- >>> from PIL import Image
-
- >>> # let's download an initial image
- >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
-
- >>> response = requests.get(url)
- >>> image = Image.open(BytesIO(response.content)).convert("RGB")
-
- >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained(
- ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
- ... )
- >>> pipe = pipe.to("cuda")
-
- >>> generator = torch.Generator(device="cuda").manual_seed(0)
- >>> image = pipe(image, generator=generator).images[0]
- >>> image.save("./car_variation.png")
+ import requests
+ import torch as th
+ from io import BytesIO
+ import matplotlib.pyplot as plt
+ from PIL import Image, ImageOps
+ from diffusers import DiffusionPipeline
+
+ SEED = 1729
+ DEVICE = th.device("cuda")
+ DTYPE = th.float16
+ FIG_K = 3
+ EXAMPLE_ID = 0
+
+ # Load the custom pipeline (weights and pipeline code) from the Hub
+ pretrained_path = "bardofcodes/pattern_analogies"
+ new_pipe = DiffusionPipeline.from_pretrained(
+     pretrained_path,
+     custom_pipeline=pretrained_path,
+     trust_remote_code=True
+ )
+
+ # Download the analogy triplet A, A*, B
+ img_urls = [
+     f"https://huggingface.co/bardofcodes/pattern_analogies/resolve/main/examples/{EXAMPLE_ID}_a.png",
+     f"https://huggingface.co/bardofcodes/pattern_analogies/resolve/main/examples/{EXAMPLE_ID}_a_star.png",
+     f"https://huggingface.co/bardofcodes/pattern_analogies/resolve/main/examples/{EXAMPLE_ID}_b.png",
+ ]
+ images = []
+ for url in img_urls:
+     response = requests.get(url)
+     image = Image.open(BytesIO(response.content)).convert("RGB")
+     images.append(image)
+
+ pipe_input = [tuple(images)]  # one analogy prompt: (A, A*, B)
+
+ pipe = new_pipe.to(DEVICE, DTYPE)
+ var_images = pipe(pipe_input, num_inference_steps=50, num_images_per_prompt=3).images
+
+ # Plot the prompt triplet (top row) and the generated variations (bottom row)
+ plt.figure(figsize=(3 * FIG_K, 2 * FIG_K))
+ plt.rcParams['legend.fontsize'] = 'large'
+ for i in range(6):
+     plt.subplot(2, 3, i + 1)
+     if i == 0:
+         val_image = images[0]
+         label_str = "A"
+     elif i == 1:
+         val_image = images[1]
+         label_str = "A*"
+     elif i == 2:
+         val_image = images[2]
+         label_str = "Target"
+     else:
+         val_image = var_images[i - 3]
+         label_str = f"Variation {i - 2}"
+
+     val_image = ImageOps.expand(val_image, border=2, fill='black')
+     plt.imshow(val_image)
+     plt.scatter([], [], c="r", label=label_str)  # empty artist so the label shows in the legend
+     plt.legend(loc="lower right")
+     plt.axis('off')
+ plt.subplots_adjust(wspace=0.01, hspace=0.01)
  ```
 
  Returns:
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
- If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
- otherwise a `tuple` is returned where the first element is a list with the generated images.
+ [`~ImagePipelineOutput`] or `tuple`:
+ The generated image(s), as an [`~ImagePipelineOutput`] if `return_dict` is `True`, otherwise as a `tuple` of images.
  """
 
  # 1. Check inputs. Raise error if not correct
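
For reference, a minimal sketch of consuming the pipeline output without matplotlib: it reuses the `pipe` and `pipe_input` objects from the docstring example above and saves each generated B* variation to disk. The output directory name is arbitrary and chosen only for illustration.

```py
from pathlib import Path

# Minimal sketch, assuming `pipe` and `pipe_input` are set up as in the example above.
out_dir = Path("analogy_variations")  # arbitrary directory name for illustration
out_dir.mkdir(exist_ok=True)

# With the default return_dict=True, the pipeline returns an ImagePipelineOutput
# whose `.images` field is a list of PIL images.
result = pipe(pipe_input, num_inference_steps=50, num_images_per_prompt=3)
for idx, variation in enumerate(result.images):
    variation.save(out_dir / f"b_star_{idx}.png")
```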