|
import os
from typing import Union

import numpy as np
import PIL.Image
import PIL.ImageOps
import torch
import torch.nn.functional as F
import torchvision.transforms as TT
from PIL import Image, ImageOps, ImageSequence

import folder_paths
from nodes import LoadImage
|
|
|
def tokenize_with_trigger_word(tokens, weights, num_images, num_tokens, img_token,
                               start_token=49406, end_token=49407, pad_token=0,
                               max_len=77, return_mask=False):
    """
    Filters out the image trigger token(s), repeats the token preceding each
    trigger `num_images * num_tokens` times in its place, and rebatches the
    result into `max_len`-long sequences with start/end/pad tokens restored.

    Returns `(tokens, weights, count)`, where `count` is the number of chunks
    that contained a trigger token; the inputs are returned unchanged when no
    trigger is found. With `return_mask=True`, the repeated positions are
    filled with -1 instead so they can be located downstream.
    """
    count = 0
    # Strip start/end/pad tokens so only the prompt tokens remain.
    mask = (tokens != start_token) & (tokens != end_token) & (tokens != pad_token)
    clean_tokens, clean_weights = tokens[mask], weights[mask]

    # Split after every occurrence of the trigger token, so each chunk except
    # possibly the last one ends with a trigger.
    img_token_indices = (clean_tokens == img_token).nonzero().view(-1)
    split = torch.tensor_split(clean_tokens, img_token_indices + 1, dim=-1)
    split_weights = torch.tensor_split(clean_weights, img_token_indices + 1, dim=-1)

    tt = []
    ww = []
    for chunk, chunk_weights in zip(split, split_weights):
        img_token_exists = chunk == img_token
        img_token_not_exists = ~img_token_exists
        # Each trigger found in the chunk expands to num_images * num_tokens positions.
        pad_amount = img_token_exists.nonzero().view(-1).shape[0] * num_images * num_tokens
        chunk_clean, chunk_weights_clean = chunk[img_token_not_exists], chunk_weights[img_token_not_exists]
        # Chunks without a trigger token are skipped.
        if pad_amount > 0 and len(chunk_clean) > 0:
            count += 1
            # Drop the trigger and repeat the preceding token in its place
            # (or fill with -1 when building a mask).
            tt.append(F.pad(chunk_clean[:-1], (0, pad_amount), 'constant',
                            chunk_clean[-1].item() if not return_mask else -1))
            ww.append(F.pad(chunk_weights_clean[:-1], (0, pad_amount), 'constant',
                            chunk_weights_clean[-1].item() if not return_mask else -1))

    if count == 0:
        return (tokens, weights, count)

    # Rebatch into max_len-long sequences, re-attaching start/end tokens.
    out = []
    outw = []
    one = torch.tensor([1.0])
    for tc, tcw in zip(torch.cat(tt).split(max_len - 2), torch.cat(ww).split(max_len - 2)):
        out.append(torch.cat([torch.tensor([start_token]), tc, torch.tensor([end_token])]))
        outw.append(torch.cat([one, tcw, one]))

    out = torch.nn.utils.rnn.pad_sequence(out, batch_first=True, padding_value=pad_token)
    outw = torch.nn.utils.rnn.pad_sequence(outw, batch_first=True, padding_value=1.0)

    # Pad every sequence up to max_len.
    out = F.pad(out, (0, max(0, max_len - out.shape[1])), 'constant', pad_token)
    outw = F.pad(outw, (0, max(0, max_len - outw.shape[1])), 'constant', 1.0)

    return (out, outw, count)
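

# A minimal sketch of how tokenize_with_trigger_word rebatches a prompt. The
# token ids below are made up for illustration: 999 stands in for the trigger
# token and the other values are arbitrary vocabulary ids.
def _demo_tokenize_with_trigger_word():
    # "<start> 320 1125 539 320 2368 <trigger> <end> <pad> <pad>"
    tokens = torch.tensor([49406, 320, 1125, 539, 320, 2368, 999, 49407, 0, 0])
    weights = torch.ones_like(tokens, dtype=torch.float32)
    out, outw, count = tokenize_with_trigger_word(tokens, weights, num_images=1, num_tokens=2, img_token=999)
    # count == 1; out is a (1, 77) batch where 2368 (the token preceding the
    # trigger) now occupies num_images * num_tokens == 2 positions.
    print(count, out.shape, outw.shape)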
|
|
|
def load_pil_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            import requests  # lazy import: only needed for remote images
            img = Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image_path = folder_paths.get_annotated_filepath(image)
            img = Image.open(image_path)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    elif isinstance(image, PIL.Image.Image):
        img = image
    else:
        raise ValueError(
            "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
        )
    return img
|
|
|
|
|
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
    """
    Loads `image` into a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.

    Returns:
        `PIL.Image.Image`:
            A PIL Image.
    """
    image = load_pil_image(image)
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
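

# A minimal usage sketch for load_image. Besides a PIL image as shown here, a
# local file path or an http(s) URL (the latter requiring the optional
# `requests` dependency) is accepted as well.
def _demo_load_image():
    src = PIL.Image.new("RGBA", (64, 64))
    img = load_image(src)
    print(img.size, img.mode)  # (64, 64) RGB -- always RGB after conversion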
|
|
|
class LoadImageCustom(LoadImage):
    def load_image(self, image):
        img = load_pil_image(image)
        output_images = []
        output_masks = []
        for i in ImageSequence.Iterator(img):
            i = ImageOps.exif_transpose(i)
            if i.mode == 'I':
                # Rescale high-bit-depth integer images before RGB conversion.
                i = i.point(lambda p: p * (1 / 255))
            image = i.convert("RGB")
            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            if 'A' in i.getbands():
                # The inverted alpha channel becomes the mask (1.0 = transparent).
                mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
                mask = 1. - torch.from_numpy(mask)
            else:
                mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
            output_images.append(image)
            output_masks.append(mask.unsqueeze(0))

        # Multi-frame images (e.g. GIFs) become a batch; otherwise return the single frame.
        if len(output_images) > 1:
            output_image = torch.cat(output_images, dim=0)
            output_mask = torch.cat(output_masks, dim=0)
        else:
            output_image = output_images[0]
            output_mask = output_masks[0]

        return (output_image, output_mask)
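

# A small sketch of LoadImageCustom's outputs, assuming a ComfyUI runtime (the
# node builds on ComfyUI's LoadImage): images come back as (B, H, W, C) floats
# in the 0-1 range and masks as (B, H, W) floats with 1.0 marking transparency.
def _demo_load_image_custom():
    rgba = Image.new("RGBA", (32, 32), (255, 0, 0, 0))  # fully transparent red
    image, mask = LoadImageCustom().load_image(rgba)
    print(image.shape, mask.shape, mask.max())  # (1, 32, 32, 3) (1, 32, 32) 1.0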
|
|
|
def crop_image_pil(image, crop_position):
    """
    Crop a PIL image to a square based on the specified crop_position, or pad
    it to a square when crop_position is "pad".

    Parameters:
    - image: PIL Image object
    - crop_position: One of "top", "bottom", "left", "right", "center", or "pad"

    Returns:
    - Cropped (or padded) PIL Image object
    """
    width, height = image.size

    if "pad" in crop_position:
        # Pad the shorter side with black so the image becomes square.
        target_length = max(width, height)
        pad_l = max((target_length - width) // 2, 0)
        pad_t = max((target_length - height) // 2, 0)
        return ImageOps.expand(image, border=(pad_l, pad_t, target_length - width - pad_l, target_length - height - pad_t), fill=0)

    # Default to a centered square crop, shifted toward the requested edge.
    crop_size = min(width, height)
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2

    if "top" in crop_position:
        top = 0
    elif "bottom" in crop_position:
        top = height - crop_size
    elif "left" in crop_position:
        left = 0
    elif "right" in crop_position:
        left = width - crop_size

    return image.crop((left, top, left + crop_size, top + crop_size))
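

# A small sketch of crop_image_pil on a synthetic 300x200 landscape image:
# every crop position yields a square with side min(width, height), while
# "pad" instead returns a max(width, height) square with black borders.
def _demo_crop_image_pil():
    img = Image.new("RGB", (300, 200))
    print(crop_image_pil(img, "center").size)  # (200, 200), centered
    print(crop_image_pil(img, "left").size)    # (200, 200), taken from x = 0
    print(crop_image_pil(img, "pad").size)     # (300, 300), black padding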
|
|
|
def prepImages(images, *args, **kwargs):
    to_tensor = TT.ToTensor()
    images_ = []
    for img in images:
        image = to_tensor(img)
        if len(image.shape) <= 3:
            image.unsqueeze_(0)
        # prepImage expects channels-last (B, H, W, C) tensors.
        images_.append(prepImage(image.movedim(1, -1), *args, **kwargs))
    return torch.cat(images_)
|
|
|
def prepImage(image, interpolation="LANCZOS", crop_position="center", size=(224, 224), sharpening=0.0, padding=0):
    # Note: `sharpening` is currently accepted but not applied.
    _, oh, ow, _ = image.shape
    output = image.permute([0, 3, 1, 2])

    if "pad" in crop_position:
        # Pad the shorter side with black so the image becomes square.
        target_length = max(oh, ow)
        pad_l = (target_length - ow) // 2
        pad_r = (target_length - ow) - pad_l
        pad_t = (target_length - oh) // 2
        pad_b = (target_length - oh) - pad_t
        output = F.pad(output, (pad_l, pad_r, pad_t, pad_b), value=0, mode="constant")
    else:
        # Default to a centered square crop, shifted toward the requested edge.
        crop_size = min(oh, ow)
        x = (ow - crop_size) // 2
        y = (oh - crop_size) // 2
        if "top" in crop_position:
            y = 0
        elif "bottom" in crop_position:
            y = oh - crop_size
        elif "left" in crop_position:
            x = 0
        elif "right" in crop_position:
            x = ow - crop_size

        output = output[:, :, y:y + crop_size, x:x + crop_size]

    # Resize through PIL so the requested resampling filter is available.
    imgs = []
    to_PIL_image = TT.ToPILImage()
    to_tensor = TT.ToTensor()
    for i in range(output.shape[0]):
        img = to_PIL_image(output[i])
        img = img.resize(size, resample=PIL.Image.Resampling[interpolation])
        imgs.append(to_tensor(img))
    output = torch.stack(imgs, dim=0)

    imgs = None

    if padding > 0:
        # Tensors are in the 0-1 range here, so pad with 1.0 for a white border.
        output = F.pad(output, (padding, padding, padding, padding), value=1.0, mode="constant")

    output = output.permute([0, 2, 3, 1])

    return output
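

# A minimal sketch of the full preprocessing path, assuming CLIP-style
# 224x224 inputs: prepImages turns a list of PIL images into one channels-last
# (B, 224, 224, 3) float tensor in the 0-1 range, cropped and then resized.
def _demo_prep_images():
    pils = [Image.new("RGB", (640, 480)), Image.new("RGB", (480, 640))]
    batch = prepImages(pils, interpolation="LANCZOS", crop_position="center", size=(224, 224))
    print(batch.shape)  # torch.Size([2, 224, 224, 3])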
|
|