"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import numpy as np
import torch
from diffusers.utils.pil_utils import PIL_INTERPOLATION
from PIL import Image

from lavis.common.annotator.canny import CannyDetector
from lavis.common.annotator.util import HWC3, resize_image

# Module-level Canny edge detector shared by the helpers below.
apply_canny = CannyDetector()


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        # Promote a single HWC image to a batch of one.
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
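
# Usage sketch (inputs in [0, 1] are assumed, per the `* 255` scaling above):
#
#     pil_list = numpy_to_pil(np.random.rand(2, 64, 64, 3))
#     # -> [<PIL.Image.Image ...>, <PIL.Image.Image ...>]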


def preprocess_canny(
    input_image: np.ndarray,
    image_resolution: int,
    low_threshold: int,
    high_threshold: int,
):
    """Resize the input to image_resolution and return its Canny edge map as a PIL image."""
    image = resize_image(HWC3(input_image), image_resolution)
    control_image = apply_canny(image, low_threshold, high_threshold)
    # The Canny output is single-channel; HWC3 replicates it to three channels.
    control_image = HWC3(control_image)
    # An inverted visualization (255 - control_image) used to be returned as
    # well; only the control image itself is needed here.
    return Image.fromarray(control_image)


def generate_canny(cond_image_input, low_threshold, high_threshold):
    """Run Canny preprocessing at a fixed 512 resolution on a PIL image or array."""
    # Convert the conditioning image to a uint8 numpy array.
    cond_image_input = np.array(cond_image_input).astype(np.uint8)
    control_image = preprocess_canny(
        cond_image_input, 512, low_threshold=low_threshold, high_threshold=high_threshold
    )
    return control_image
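
# A minimal usage sketch (assumes an RGB image at "example.jpg"; 100/200 are
# conventional Canny thresholds, not values fixed by this module):
#
#     edge_map = generate_canny(Image.open("example.jpg"), 100, 200)
#     edge_map.save("canny.png")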


def prepare_cond_image(
    image, width, height, batch_size, device, do_classifier_free_guidance=True
):
    """Convert a conditioning image (PIL, list, or tensor) to a batched NCHW float tensor."""
    if not isinstance(image, torch.Tensor):
        if isinstance(image, Image.Image):
            image = [image]

        if isinstance(image[0], Image.Image):
            images = []
            for image_ in image:
                image_ = image_.convert("RGB")
                image_ = image_.resize(
                    (width, height), resample=PIL_INTERPOLATION["lanczos"]
                )
                image_ = np.array(image_)
                image_ = image_[None, :]
                images.append(image_)
            image = images

            # Stack into a single NHWC batch, scale to [0, 1], and convert to NCHW.
            image = np.concatenate(image, axis=0)
            image = np.array(image).astype(np.float32) / 255.0
            image = image.transpose(0, 3, 1, 2)
            image = torch.from_numpy(image)
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)

    image_batch_size = image.shape[0]
    if image_batch_size == 1:
        # Broadcast a single conditioning image across the whole prompt batch.
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        # repeat_by = num_images_per_prompt
        raise NotImplementedError

    image = image.repeat_interleave(repeat_by, dim=0)
    image = image.to(device=device)

    if do_classifier_free_guidance:
        # Duplicate the batch so the unconditional and conditional passes
        # share one tensor.
        image = torch.cat([image] * 2)

    return image
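
# Example: one conditioning image prepared for a prompt batch of 2 with
# classifier-free guidance (illustrative values, not part of the original file):
#
#     cond = prepare_cond_image(
#         Image.open("canny.png"), width=512, height=512,
#         batch_size=2, device="cuda",
#     )
#     # cond.shape -> torch.Size([4, 3, 512, 512]); doubled along dim 0 for CFG.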