""" | |
Copyright (c) 2023, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import numpy as np
import PIL
import torch
from diffusers.utils.pil_utils import PIL_INTERPOLATION
from PIL import Image

from lavis.common.annotator.canny import CannyDetector
from lavis.common.annotator.util import HWC3, resize_image

apply_canny = CannyDetector()


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        # Promote a single HWC image to a one-element batch.
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
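
# Usage sketch (illustrative, not executed on import): numpy_to_pil expects
# float images in [0, 1], either a single HWC array or an NHWC batch, e.g.
#
#   batch = np.random.rand(2, 64, 64, 3)  # hypothetical NHWC batch
#   pils = numpy_to_pil(batch)            # -> list of two PIL.Image objects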


def preprocess_canny(
    input_image: np.ndarray,
    image_resolution: int,
    low_threshold: int,
    high_threshold: int,
):
    # Resize to the target resolution, run the Canny edge detector, and
    # broadcast the single-channel edge map back to three channels.
    image = resize_image(HWC3(input_image), image_resolution)
    control_image = apply_canny(image, low_threshold, high_threshold)
    control_image = HWC3(control_image)
    # vis_control_image = 255 - control_image
    # return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
    #     vis_control_image)
    return PIL.Image.fromarray(control_image)


def generate_canny(cond_image_input, low_threshold, high_threshold):
    # Convert the conditioning image (e.g. a PIL image) to a uint8 numpy array.
    cond_image_input = np.array(cond_image_input).astype(np.uint8)
    # canny_input, vis_control_image = preprocess_canny(cond_image_input, 512, low_threshold=100, high_threshold=200)
    control_image = preprocess_canny(
        cond_image_input, 512, low_threshold=low_threshold, high_threshold=high_threshold
    )
    return control_image
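
# Usage sketch (illustrative): turning an RGB PIL image into a Canny edge map
# for ControlNet-style conditioning. The thresholds are the standard OpenCV
# Canny hysteresis bounds; 100/200 is a common default. "photo.png" is a
# hypothetical filename.
#
#   edge_map = generate_canny(Image.open("photo.png").convert("RGB"),
#                             low_threshold=100, high_threshold=200)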


def prepare_cond_image(
    image, width, height, batch_size, device, do_classifier_free_guidance=True
):
    if not isinstance(image, torch.Tensor):
        if isinstance(image, Image.Image):
            image = [image]

        if isinstance(image[0], Image.Image):
            images = []
            for image_ in image:
                image_ = image_.convert("RGB")
                image_ = image_.resize(
                    (width, height), resample=PIL_INTERPOLATION["lanczos"]
                )
                image_ = np.array(image_)
                image_ = image_[None, :]
                images.append(image_)
            image = images

            # Stack into a batch, scale to [0, 1], and reorder to NCHW.
            image = np.concatenate(image, axis=0)
            image = np.array(image).astype(np.float32) / 255.0
            image = image.transpose(0, 3, 1, 2)
            image = torch.from_numpy(image)
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)

    image_batch_size = image.shape[0]
    if image_batch_size == 1:
        # Broadcast a single conditioning image across the prompt batch.
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        # repeat_by = num_images_per_prompt
        raise NotImplementedError

    image = image.repeat_interleave(repeat_by, dim=0)
    # image = image.to(device=self.device, dtype=dtype)
    image = image.to(device=device)

    if do_classifier_free_guidance:
        # Duplicate so the unconditional and conditional passes share the
        # same conditioning image.
        image = torch.cat([image] * 2)

    return image
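

if __name__ == "__main__":
    # Minimal smoke test (a sketch, assuming the lavis annotator modules are
    # importable and apply_canny accepts (image, low_threshold, high_threshold)).
    rgb = Image.fromarray((np.random.rand(512, 512, 3) * 255).astype(np.uint8))
    edge_map = generate_canny(rgb, low_threshold=100, high_threshold=200)
    cond = prepare_cond_image(
        edge_map,
        width=512,
        height=512,
        batch_size=2,
        device="cpu",
        do_classifier_free_guidance=True,
    )
    # With classifier-free guidance the batch is doubled:
    print(cond.shape)  # expected: torch.Size([4, 3, 512, 512])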