bardofcodes committed
Update pipeline.py

pipeline.py CHANGED (+70 -32)
@@ -18,12 +18,7 @@ import einops
 import PIL.Image
 import numpy as np
 import torch as th
-import torch.nn as nn
-from torchvision import transforms

-from diffusers import ModelMixin
-from transformers import AutoModel, AutoConfig, SiglipVisionConfig, Dinov2Config, Dinov2Model
-from transformers import SiglipVisionModel
 from diffusers import DiffusionPipeline
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -31,8 +26,6 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput

-from diffusers.configuration_utils import ConfigMixin, register_to_config
-# REf: https://github.com/tatp22/multidim-positional-encoding/tree/master
 from analogy_encoder import AnalogyEncoder
 from analogy_projector import AnalogyProjector
 from analogy_input_processor import AnalogyInputProcessor
@@ -259,8 +252,8 @@ class PatternAnalogyTrifuser(DiffusionPipeline):
 The call function to the pipeline for generation.

 Args:
-
-    The
+    analogy_prompt (`List[Tuple[PIL.Image.Image]]`):
+        The analogy sequence (A, A*, B) used as the model's prompt for generating B*, the pattern that satisfies A : A* :: B : B*.
     height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
         The height in pixels of the generated image.
     width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
@@ -301,32 +294,77 @@ class PatternAnalogyTrifuser(DiffusionPipeline):
 Examples:

 ```py
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+import requests
+import torch as th
+from PIL import Image
+from io import BytesIO
+import matplotlib.pyplot as plt
+from PIL import Image, ImageOps
+from diffusers import DiffusionPipeline
+
+SEED = 1729
+DEVICE = th.device("cuda")
+DTYPE = th.float16
+FIG_K = 3
+EXAMPLE_ID = 0
+
+# Load the custom pipeline (code and weights) from the Hub
+pretrained_path = "bardofcodes/pattern_analogies"
+new_pipe = DiffusionPipeline.from_pretrained(
+    pretrained_path,
+    custom_pipeline=pretrained_path,
+    trust_remote_code=True
+)
+
+# Download the example analogy triplet A, A*, B
+img_urls = [
+    f"https://huggingface.co/bardofcodes/pattern_analogies/resolve/main/examples/{EXAMPLE_ID}_a.png",
+    f"https://huggingface.co/bardofcodes/pattern_analogies/resolve/main/examples/{EXAMPLE_ID}_a_star.png",
+    f"https://huggingface.co/bardofcodes/pattern_analogies/resolve/main/examples/{EXAMPLE_ID}_b.png",
+]
+images = []
+for url in img_urls:
+    response = requests.get(url)
+    image = Image.open(BytesIO(response.content)).convert("RGB")
+    images.append(image)
+
+pipe_input = [tuple(images)]
+
+pipe = new_pipe.to(DEVICE, DTYPE)
+var_images = pipe(pipe_input, num_inference_steps=50, num_images_per_prompt=3).images
+
+plt.figure(figsize=(3 * FIG_K, 2 * FIG_K))
+plt.axis('off')
+plt.rcParams['legend.fontsize'] = 'large'
+for i in range(6):
+    if i == 0:
+        plt.subplot(2, 3, i + 1)
+        val_image = images[0]
+        label_str = "A"
+    elif i == 1:
+        plt.subplot(2, 3, i + 1)
+        val_image = images[1]
+        label_str = "A*"
+    elif i == 2:
+        plt.subplot(2, 3, i + 1)
+        val_image = images[2]
+        label_str = "Target"
+    else:
+        plt.subplot(2, 3, i + 1)
+        val_image = var_images[i - 3]
+        label_str = f"Variation {i - 2}"
+
+    val_image = ImageOps.expand(val_image, border=2, fill='black')
+    plt.imshow(val_image)
+    plt.scatter([], [], c="r", label=label_str)
+    plt.legend(loc="lower right")
+    plt.axis('off')
+    plt.subplots_adjust(wspace=0.01, hspace=0.01)
 ```

 Returns:
-    [`~
-
-    otherwise a `tuple` is returned where the first element is a list with the generated images.
+    [`~ImagePipelineOutput`] or `tuple`:
+        The generated image(s) as an [`~ImagePipelineOutput`] or a `tuple` of images.
 """

 # 1. Check inputs. Raise error if not correct
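
For reference, here is a minimal sketch of the calling convention this change documents: one analogy prompt is a 3-tuple of RGB PIL images (A, A*, B), and the pipeline takes a list of such tuples. The repository id, the `custom_pipeline` argument, the call arguments, and the `.images` output field are taken from the docstring example above; the local image paths and the output directory are placeholders introduced only for illustration.

```py
# Minimal sketch (assumptions: placeholder local image paths, a CUDA device is available).
from pathlib import Path

import torch as th
from PIL import Image
from diffusers import DiffusionPipeline

# Load the custom pipeline exactly as in the docstring example above.
repo_id = "bardofcodes/pattern_analogies"
pipe = DiffusionPipeline.from_pretrained(
    repo_id,
    custom_pipeline=repo_id,
    trust_remote_code=True,
).to(th.device("cuda"), th.float16)

# One analogy prompt = an (A, A*, B) tuple of RGB PIL images; the pipeline takes a list of them.
a = Image.open("a.png").convert("RGB")            # placeholder path
a_star = Image.open("a_star.png").convert("RGB")  # placeholder path
b = Image.open("b.png").convert("RGB")            # placeholder path
analogy_prompt = [(a, a_star, b)]

# Generate three candidate B* images and save them to disk.
result = pipe(analogy_prompt, num_inference_steps=50, num_images_per_prompt=3)
out_dir = Path("outputs")  # assumed output directory
out_dir.mkdir(exist_ok=True)
for idx, img in enumerate(result.images):
    img.save(out_dir / f"b_star_{idx}.png")
```

Saving to disk avoids the matplotlib dependency of the docstring example; the diff does not show whether a seed or `generator` argument is supported, so none is passed here.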
|