import os
import PIL
import PIL.Image
import PIL.ImageOps
import torch
from typing import Union
import torch.nn.functional as F
import torchvision.transforms as TT
def tokenize_with_trigger_word(tokens, weights, num_images, num_tokens, img_token, start_token=49406, end_token=49407, pad_token=0, max_len=77, return_mask=False):
    """
    Removes the trigger (image) token from the token sequence and replaces it,
    together with the token immediately before it, with `num_images * num_tokens`
    copies of that preceding token (or -1 when `return_mask` is True), then
    rebatches the result into sequences of at most `max_len` tokens with
    start/end/pad tokens restored. Chunks that contain no trigger token (e.g.
    text after the last occurrence) are dropped.

    Returns a tuple of (token ids, weights, number of trigger tokens found);
    the input is returned unchanged when no trigger token is present.
    """
    count = 0
    # drop start/end/pad tokens before searching for the trigger token
    mask = (tokens != start_token) & (tokens != end_token) & (tokens != pad_token)
    clean_tokens, clean_tokens_mask = tokens[mask], weights[mask]
    img_token_indices = (clean_tokens == img_token).nonzero().view(-1)
    # split right after every occurrence of the trigger token
    split = torch.tensor_split(clean_tokens, img_token_indices + 1, dim=-1)
    split_mask = torch.tensor_split(clean_tokens_mask, img_token_indices + 1, dim=-1)
    tt = []
    ww = []
    for chunk, chunk_mask in zip(split, split_mask):
        img_token_exists = chunk == img_token
        img_token_not_exists = ~img_token_exists
        pad_amount = img_token_exists.nonzero().view(-1).shape[0] * num_images * num_tokens
        chunk_clean, chunk_mask_clean = chunk[img_token_not_exists], chunk_mask[img_token_not_exists]
        if pad_amount > 0 and len(chunk_clean) > 0:
            count += 1
            tt.append(torch.nn.functional.pad(chunk_clean[:-1], (0, pad_amount), 'constant', chunk_clean[-1] if not return_mask else -1))
            ww.append(torch.nn.functional.pad(chunk_mask_clean[:-1], (0, pad_amount), 'constant', chunk_mask_clean[-1] if not return_mask else -1))
    if count == 0:
        return (tokens, weights, count)

    # rebatch and pad
    out = []
    outw = []
    one = torch.tensor([1.0])
    for tc, tcw in zip(torch.cat(tt).split(max_len - 2), torch.cat(ww).split(max_len - 2)):
        out.append(torch.cat([torch.tensor([start_token]), tc, torch.tensor([end_token])]))
        outw.append(torch.cat([one, tcw, one]))
    out = torch.nn.utils.rnn.pad_sequence(out, batch_first=True, padding_value=pad_token)
    outw = torch.nn.utils.rnn.pad_sequence(outw, batch_first=True, padding_value=1.0)
    out = torch.nn.functional.pad(out, (0, max(0, max_len - out.shape[1])), 'constant', pad_token)
    outw = torch.nn.functional.pad(outw, (0, max(0, max_len - outw.shape[1])), 'constant', 1.0)
    return (out, outw, count)
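# Usage sketch (hypothetical token ids, not from any real tokenizer vocabulary):
# with 320/539 standing in for prompt tokens and 99 for the trigger token, the
# trigger and the token right before it are replaced by num_images * num_tokens
# copies of that preceding token, and the result is rebatched to max_len columns.
#
#   tokens  = torch.tensor([49406, 320, 99, 539, 49407, 0, 0])
#   weights = torch.ones_like(tokens, dtype=torch.float32)
#   out, outw, found = tokenize_with_trigger_word(tokens, weights, num_images=1, num_tokens=4, img_token=99)
#   # out.shape == (1, 77), found == 1; note that tokens after the last trigger
#   # (539 here) are dropped, since only chunks containing the trigger are kept.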
def load_pil_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            import requests
            img = Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image_path = folder_paths.get_annotated_filepath(image)
            img = Image.open(image_path)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path"
            )
    elif isinstance(image, PIL.Image.Image):
        img = image
    else:
        raise ValueError(
            "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
        )
    return img
# adapted from diffusers.utils.load_image
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
    """
    Loads `image` into a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.

    Returns:
        `PIL.Image.Image`:
            A PIL Image.
    """
    image = load_pil_image(image)
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
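# Usage sketch (hypothetical URL/path): each call returns an RGB, EXIF-rotated
# PIL.Image.Image. Local paths must exist on disk and are resolved through
# ComfyUI's folder_paths.get_annotated_filepath; PIL images pass through as-is.
#
#   img = load_image("https://example.com/portrait.png")
#   img = load_image("input/portrait.png")
#   img = load_image(PIL.Image.new("RGB", (512, 512)))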
from PIL import Image, ImageSequence, ImageOps
import numpy as np
import folder_paths
from nodes import LoadImage
class LoadImageCustom(LoadImage):
    """LoadImage variant that also accepts URLs and PIL images (via load_pil_image)."""
    def load_image(self, image):
        # image_path = folder_paths.get_annotated_filepath(image)
        # img = Image.open(image_path)
        img = load_pil_image(image)
        output_images = []
        output_masks = []
        for i in ImageSequence.Iterator(img):
            i = ImageOps.exif_transpose(i)
            if i.mode == 'I':
                i = i.point(lambda i: i * (1 / 255))
            image = i.convert("RGB")
            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            if 'A' in i.getbands():
                mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
                mask = 1. - torch.from_numpy(mask)
            else:
                mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
            output_images.append(image)
            output_masks.append(mask.unsqueeze(0))

        if len(output_images) > 1:
            output_image = torch.cat(output_images, dim=0)
            output_mask = torch.cat(output_masks, dim=0)
        else:
            output_image = output_images[0]
            output_mask = output_masks[0]

        return (output_image, output_mask)
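# Usage sketch (assumes a ComfyUI runtime, hypothetical source): returns the
# standard LoadImage tuple, i.e. a float32 image tensor of shape (frames, H, W, 3)
# in [0, 1] and a mask of shape (frames, H, W) where 1.0 marks transparent pixels
# (a 64x64 zero mask is used when the input has no alpha channel).
#
#   image, mask = LoadImageCustom().load_image("https://example.com/face.png")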
def crop_image_pil(image, crop_position):
    """
    Crop a PIL image to a square based on the specified crop_position.

    Parameters:
    - image: PIL Image object
    - crop_position: one of "top", "bottom", "left", "right", "center", or "pad"

    Returns:
    - Cropped (or, for "pad", letterboxed) PIL Image object
    """
    width, height = image.size

    if "pad" in crop_position:
        # pad to a square instead of cropping
        target_length = max(width, height)
        pad_l = max((target_length - width) // 2, 0)
        pad_t = max((target_length - height) // 2, 0)
        return ImageOps.expand(image, border=(pad_l, pad_t, target_length - width - pad_l, target_length - height - pad_t), fill=0)
    else:
        # square crop anchored at the requested position (centered by default),
        # mirroring the tensor version in prepImage below
        crop_size = min(width, height)
        x = (width - crop_size) // 2
        y = (height - crop_size) // 2

        if "top" in crop_position:
            y = 0
        elif "bottom" in crop_position:
            y = height - crop_size
        elif "left" in crop_position:
            x = 0
        elif "right" in crop_position:
            x = width - crop_size

        return image.crop((x, y, x + crop_size, y + crop_size))
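# Usage sketch (hypothetical sizes): "center"/"top"/"bottom"/"left"/"right" return
# a crop_size x crop_size square (crop_size = min(width, height)) anchored at the
# given edge, while "pad" letterboxes the image to a max(width, height) square
# with black borders instead of cropping.
#
#   img = Image.new("RGB", (640, 480))
#   crop_image_pil(img, "center").size   # (480, 480)
#   crop_image_pil(img, "left").size     # (480, 480)
#   crop_image_pil(img, "pad").size      # (640, 640)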
def prepImages(images, *args, **kwargs):
    to_tensor = TT.ToTensor()
    images_ = []
    for img in images:
        image = to_tensor(img)
        if len(image.shape) <= 3:
            image.unsqueeze_(0)
        # move channels last: prepImage expects (B, H, W, C)
        images_.append(prepImage(image.movedim(1, -1), *args, **kwargs))
    return torch.cat(images_)
def prepImage(image, interpolation="LANCZOS", crop_position="center", size=(224, 224), sharpening=0.0, padding=0):
    # note: `sharpening` is currently unused
    _, oh, ow, _ = image.shape
    output = image.permute([0, 3, 1, 2])

    if "pad" in crop_position:
        # pad to a square instead of cropping
        target_length = max(oh, ow)
        pad_l = (target_length - ow) // 2
        pad_r = (target_length - ow) - pad_l
        pad_t = (target_length - oh) // 2
        pad_b = (target_length - oh) - pad_t
        output = F.pad(output, (pad_l, pad_r, pad_t, pad_b), value=0, mode="constant")
    else:
        # square crop anchored at the requested position (centered by default)
        crop_size = min(oh, ow)
        x = (ow - crop_size) // 2
        y = (oh - crop_size) // 2
        if "top" in crop_position:
            y = 0
        elif "bottom" in crop_position:
            y = oh - crop_size
        elif "left" in crop_position:
            x = 0
        elif "right" in crop_position:
            x = ow - crop_size
        x2 = x + crop_size
        y2 = y + crop_size

        # crop
        output = output[:, :, y:y2, x:x2]

    # resize (apparently PIL resize is better than torchvision interpolate)
    imgs = []
    to_PIL_image = TT.ToPILImage()
    to_tensor = TT.ToTensor()
    for i in range(output.shape[0]):
        img = to_PIL_image(output[i])
        img = img.resize(size, resample=PIL.Image.Resampling[interpolation])
        imgs.append(to_tensor(img))
    output = torch.stack(imgs, dim=0)
    imgs = None  # zealous GC

    if padding > 0:
        output = F.pad(output, (padding, padding, padding, padding), value=255, mode="constant")

    output = output.permute([0, 2, 3, 1])
    return output
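# Usage sketch (hypothetical shapes): prepImage works on channels-last batches as
# used by ComfyUI, i.e. (B, H, W, C) float tensors in [0, 1], and returns the same
# layout at the target size; prepImages does the PIL -> tensor conversion first.
#
#   batch = torch.rand(2, 480, 640, 3)
#   clip_in = prepImage(batch, interpolation="LANCZOS", crop_position="left", size=(224, 224))
#   # clip_in.shape == torch.Size([2, 224, 224, 3])
#
#   pils = [Image.new("RGB", (640, 480)), Image.new("RGB", (512, 512))]
#   clip_in = prepImages(pils, crop_position="center", size=(224, 224))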